Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
yangzq50 committed Nov 14, 2024
1 parent 314a9ba commit eb8d90e
Showing 1 changed file with 87 additions and 5 deletions.
92 changes: 87 additions & 5 deletions src/common/analyzer/rag_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ module;
#include <pcre2.h>
#include <re2/re2.h>
#include <sstream>
#include <iostream>
#include <chrono>

#include "string_utils.h"

Expand Down Expand Up @@ -957,6 +959,62 @@ void RAGAnalyzer::EnglishNormalize(const Vector<String> &tokens, Vector<String>
}
}

// TODO: for test
// #ifndef INFINITY_DEBUG
// #define INFINITY_DEBUG 1
// #endif

#ifdef INFINITY_DEBUG
template <typename T>
String TestPrintTokens(const Vector<T> &tokens) {
std::ostringstream oss;
for (std::size_t i = 0; i < tokens.size(); ++i) {
oss << (i ? " #" : "#");
oss << tokens[i];
oss << "#";
}
return std::move(oss).str();
}

inline void CheckDP(const RAGAnalyzer *this_ptr,
const std::string_view input_str,
const Vector<String> &dfs_tokens,
const double dfs_score,
const auto t0,
const auto t1) {
const auto [dp_vec, dp_score] = this_ptr->GetBestTokens(input_str);
const auto t2 = std::chrono::high_resolution_clock::now();
const auto dfs_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t1 - t0);
const auto dp_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t2 - t1);
auto print_1 = [](const bool b) {
return b ? "✅✅✅" : "❌❌❌";
};
auto print_2 = [](const bool b) {
return b ? "" : " not";
};
const auto dp_faster = dp_duration < dfs_duration;
std::cerr << "\n!!! " << print_1(dp_faster) << "\nDFS duration: " << dfs_duration << " \nDP duration: " << dp_duration;
const auto b_score_eq = dp_score == dfs_score;
std::cerr << std::format("\n{} DFS and DP score{} equal:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), print_2(b_score_eq), dfs_score, dp_score);
bool vec_equal = true;
if (dp_vec.size() != dfs_tokens.size()) {
vec_equal = false;
} else {
for (std::size_t k = 0; k < dp_vec.size(); ++k) {
if (dp_vec[k] != dfs_tokens[k]) {
vec_equal = false;
break;
}
}
}
std::cerr << std::format("{} DFS and DP result{} equal:\nDFS: {}\nDP : {}\n",
print_1(vec_equal),
print_2(vec_equal),
TestPrintTokens(dfs_tokens),
TestPrintTokens(dp_vec));
}
#endif

void RAGAnalyzer::TokenizeInner(Vector<String> &res, const String &L) const {
auto [tks, s] = MaxForward(L);
auto [tks1, s1] = MaxBackward(L);
Expand Down Expand Up @@ -992,7 +1050,15 @@ void RAGAnalyzer::TokenizeInner(Vector<String> &res, const String &L) const {
Vector<Vector<Pair<String, int>>> token_list;
Vector<String> best_tokens;
double max_score = -100.0F;
DFS(Join(tks, _j, j, ""), 0, pre_tokens, token_list, best_tokens, max_score, false);
const auto str_for_dfs = Join(tks, _j, j, "");
#ifdef INFINITY_DEBUG
const auto t0 = std::chrono::high_resolution_clock::now();
#endif
DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
const auto t1 = std::chrono::high_resolution_clock::now();
CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
res.push_back(Join(best_tokens, 0));

same = 1;
Expand All @@ -1009,7 +1075,15 @@ void RAGAnalyzer::TokenizeInner(Vector<String> &res, const String &L) const {
Vector<Vector<Pair<String, int>>> token_list;
Vector<String> best_tokens;
double max_score = -100.0F;
DFS(Join(tks, _j, tks.size(), ""), 0, pre_tokens, token_list, best_tokens, max_score, false);
const auto str_for_dfs = Join(tks, _j, tks.size(), "");
#ifdef INFINITY_DEBUG
const auto t0 = std::chrono::high_resolution_clock::now();
#endif
DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
const auto t1 = std::chrono::high_resolution_clock::now();
CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
res.push_back(Join(best_tokens, 0));
}

Expand Down Expand Up @@ -1048,7 +1122,15 @@ void RAGAnalyzer::TokenizeInner(Vector<String> &res, const String &L) const {
Vector<Vector<Pair<String, int>>> token_list;
Vector<String> best_tokens;
double max_score = -100.0F;
DFS(Join(tks, s, e < tks.size() ? e + 1 : e, ""), 0, pre_tokens, token_list, best_tokens, max_score, false);
const auto str_for_dfs = Join(tks, s, e < tks.size() ? e + 1 : e, "");
#ifdef INFINITY_DEBUG
const auto t0 = std::chrono::high_resolution_clock::now();
#endif
DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
const auto t1 = std::chrono::high_resolution_clock::now();
CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
// Vector<Pair<Vector<String>, double>> sorted_tokens;
// SortTokens(token_list, sorted_tokens);
// res.push_back(Join(sorted_tokens[0].first, 0));
Expand Down Expand Up @@ -1146,15 +1228,15 @@ String RAGAnalyzer::Tokenize(const String &line) {
res.push_back(L);
continue;
}

#if 1
if (length > MAX_SENTENCE_LEN) {
Vector<String> sublines;
SplitLongText(L, length, sublines);
for (auto &l : sublines) {
TokenizeInner(res, l);
}
} else

#endif
TokenizeInner(res, L);
}
Vector<String> normalize_res;
Expand Down

0 comments on commit eb8d90e

Please sign in to comment.