// Patch summary (free-text preamble placed before the `diff --git` header).
//
// Target file: src/common/analyzer/rag_analyzer.cpp.
//
// NOTE(review): this copy of the patch has its line breaks collapsed, and
// angle-bracketed text (the #include targets, template parameter lists and
// template arguments) appears to be missing. Restore it from the upstream
// commit before trying to apply it.
//
// What the hunks do:
//  * Hunk 1 adds two #include lines (their targets are not visible here).
//  * Hunk 2 adds two helpers that exist only when INFINITY_DEBUG is defined:
//      - TestPrintTokens(tokens): streams every token into an ostringstream,
//        wrapping each one in '#' characters and separating entries with a
//        space, then returns the resulting string.
//      - CheckDP(this_ptr, input_str, dfs_tokens, dfs_score, t0, t1): calls
//        this_ptr->GetBestTokens(input_str), samples the clock again (t2),
//        derives the DFS duration as t1 - t0 and the DP duration as t2 - t1,
//        and writes three reports to std::cerr: whether DP was faster, whether
//        the two scores compare equal, and whether the two token sequences
//        have the same size and the same elements in order.
//  * Hunks 3-5 change three DFS call sites in RAGAnalyzer::TokenizeInner: the
//    joined input is first stored in `str_for_dfs`, and when INFINITY_DEBUG is
//    defined the DFS call is bracketed by two clock samples that are passed,
//    together with the DFS result, to CheckDP.
//  * Hunk 6 wraps the `length > MAX_SENTENCE_LEN` branch of
//    RAGAnalyzer::Tokenize in `#if 1` ... `#endif`. Because the `else` sits
//    inside the guarded region, switching it to `#if 0` would leave only the
//    unconditional TokenizeInner(res, L) call.
//
// NOTE(review): the `#if 1` guard and the commented-out
// `#define INFINITY_DEBUG 1` under "TODO: for test" look like temporary
// debugging toggles - confirm whether they are meant to be merged.
// NOTE(review): `dp_score == dfs_score` is an exact comparison of
// floating-point scores; confirm that bit-exact equality is the intended
// check rather than a tolerance.
// NOTE(review): the element-wise comparison loop in CheckDP could presumably
// be a single range comparison, but the element types of `dp_vec` and
// `dfs_tokens` are not visible in this copy - verify before simplifying.
diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index 2fbcce4c10..c058c72b53 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -23,6 +23,8 @@ module; #include #include #include +#include +#include #include "string_utils.h" @@ -957,6 +959,62 @@ void RAGAnalyzer::EnglishNormalize(const Vector &tokens, Vector } } +// TODO: for test +// #ifndef INFINITY_DEBUG +// #define INFINITY_DEBUG 1 +// #endif + +#ifdef INFINITY_DEBUG +template +String TestPrintTokens(const Vector &tokens) { + std::ostringstream oss; + for (std::size_t i = 0; i < tokens.size(); ++i) { + oss << (i ? " #" : "#"); + oss << tokens[i]; + oss << "#"; + } + return std::move(oss).str(); +} + +inline void CheckDP(const RAGAnalyzer *this_ptr, + const std::string_view input_str, + const Vector &dfs_tokens, + const double dfs_score, + const auto t0, + const auto t1) { + const auto [dp_vec, dp_score] = this_ptr->GetBestTokens(input_str); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + auto print_1 = [](const bool b) { + return b ? "✅✅✅" : "❌❌❌"; + }; + auto print_2 = [](const bool b) { + return b ? "" : " not"; + }; + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! 
" << print_1(dp_faster) << "\nDFS duration: " << dfs_duration << " \nDP duration: " << dp_duration; + const auto b_score_eq = dp_score == dfs_score; + std::cerr << std::format("\n{} DFS and DP score{} equal:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), print_2(b_score_eq), dfs_score, dp_score); + bool vec_equal = true; + if (dp_vec.size() != dfs_tokens.size()) { + vec_equal = false; + } else { + for (std::size_t k = 0; k < dp_vec.size(); ++k) { + if (dp_vec[k] != dfs_tokens[k]) { + vec_equal = false; + break; + } + } + } + std::cerr << std::format("{} DFS and DP result{} equal:\nDFS: {}\nDP : {}\n", + print_1(vec_equal), + print_2(vec_equal), + TestPrintTokens(dfs_tokens), + TestPrintTokens(dp_vec)); +} +#endif + void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { auto [tks, s] = MaxForward(L); auto [tks1, s1] = MaxBackward(L); @@ -992,7 +1050,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, _j, j, ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, _j, j, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif res.push_back(Join(best_tokens, 0)); same = 1; @@ -1009,7 +1075,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, _j, tks.size(), ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, _j, tks.size(), ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, 
max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif res.push_back(Join(best_tokens, 0)); } @@ -1048,7 +1122,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, s, e < tks.size() ? e + 1 : e, ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, s, e < tks.size() ? e + 1 : e, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif // Vector, double>> sorted_tokens; // SortTokens(token_list, sorted_tokens); // res.push_back(Join(sorted_tokens[0].first, 0)); @@ -1146,7 +1228,7 @@ String RAGAnalyzer::Tokenize(const String &line) { res.push_back(L); continue; } - +#if 1 if (length > MAX_SENTENCE_LEN) { Vector sublines; SplitLongText(L, length, sublines); @@ -1154,7 +1236,7 @@ String RAGAnalyzer::Tokenize(const String &line) { TokenizeInner(res, l); } } else - +#endif TokenizeInner(res, L); } Vector normalize_res;