Update rag_analyzer.cpp #2273

Merged (1 commit) on Nov 20, 2024

File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 73 additions & 80 deletions src/common/analyzer/rag_analyzer.cpp
@@ -821,78 +821,64 @@ int RAGAnalyzer::DFS(const String &chars,
     return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all);
 }
 
+struct TokensList {
+    const TokensList *prev = nullptr;
+    std::string_view token = {};
+};
+
 struct BestTokenCandidate {
-#ifdef DIVIDE_F_BY_N
-    u32 token_num{};
-    i64 score_sum{};
-#else
+    static constexpr i64 B = 30;
+    TokensList tl{};
     // N: token num
     // L: num of tokens with length >= 2
     // F: sum of freq
-    Pair<u32, u32> N_L{};
+    u32 N{};
+    u32 L{};
     i64 F{};
-#endif
-    Vector<std::string_view> tokens{};
+    auto k() const {
+#ifdef DIVIDE_F_BY_N
+        return N;
+#else
+        return std::make_pair(N, L);
+#endif
+    }
+    auto v() const { return F; }
+    auto score() const {
+#ifdef DIVIDE_F_BY_N
+        return static_cast<double>(B + L + F) / N;
+#else
+        return F + (static_cast<double>(B + L) / N);
+#endif
+    }
+    BestTokenCandidate update(const std::string_view new_token_sv, const i32 key_f, const u32 add_l) const {
+        return {{&tl, new_token_sv}, N + 1, L + add_l, F + key_f};
+    }
 };
 
-template <class... Fs>
-struct Overload : Fs... {
-    using Fs::operator()...;
-};
-
-// explicit deduction guide
-template <class... Fs>
-Overload(Fs...) -> Overload<Fs...>;
-
 struct GrowingBestTokenCandidatesTopN {
-    const i32 top_n{};
+    i32 top_n{};
     Vector<BestTokenCandidate> candidates{};
 
     explicit GrowingBestTokenCandidatesTopN(const i32 top_n) : top_n(top_n) {}
 
-#ifdef DIVIDE_F_BY_N
-    void AddBestTokenCandidateTopN(const u32 tn, const i64 ss, const Vector<std::string_view> &tks_old_first, const std::string_view wait_append) {
-        const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const u32 x) { return a.token_num < x; },
-                                       [](const u32 x, const BestTokenCandidate &a) { return x < a.token_num; }};
-        const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.score_sum < b.score_sum; };
-        const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), tn, e_r_comp);
-#else
-    void AddBestTokenCandidateTopN(const Pair<u32, u32> n_l,
-                                   const i64 new_f,
-                                   const Vector<std::string_view> &tks_old_first,
-                                   const std::string_view wait_append) {
-        const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const Pair<u32, u32> x) { return a.N_L < x; },
-                                       [](const Pair<u32, u32> x, const BestTokenCandidate &a) { return x < a.N_L; }};
-        const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.F < b.F; };
-        const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), n_l, e_r_comp);
-#endif
+    void AddBestTokenCandidateTopN(const BestTokenCandidate &add_candidate) {
+        const auto [it_b, it_e] =
+            std::equal_range(candidates.begin(), candidates.end(), add_candidate, [](const auto &a, const auto &b) { return a.k() < b.k(); });
         auto target_it = it_b;
         bool do_replace = false;
        if (const auto match_cnt = std::distance(it_b, it_e); match_cnt >= top_n) {
             assert(match_cnt == top_n);
-            const auto it = std::min_element(it_b, it_e, min_comp);
-#ifdef DIVIDE_F_BY_N
-            if (it->score_sum >= ss) {
-#else
-            if (it->F >= new_f) {
-#endif
+            const auto it = std::min_element(it_b, it_e, [](const auto &a, const auto &b) { return a.v() < b.v(); });
+            if (it->v() >= add_candidate.v()) {
                 return;
             }
             target_it = it;
             do_replace = true;
         }
-#ifdef DIVIDE_F_BY_N
-        BestTokenCandidate candidate = {tn, ss};
-#else
-        BestTokenCandidate candidate = {n_l, new_f};
-#endif
-        candidate.tokens.reserve(tks_old_first.size() + 1);
-        candidate.tokens.insert(candidate.tokens.end(), tks_old_first.begin(), tks_old_first.end());
-        candidate.tokens.push_back(wait_append);
         if (do_replace) {
-            *target_it = std::move(candidate);
+            *target_it = add_candidate;
         } else {
-            candidates.insert(target_it, std::move(candidate));
+            candidates.insert(target_it, add_candidate);
         }
     }
 };
@@ -925,21 +911,9 @@ Vector<Pair<Vector<std::string_view>, double>> RAGAnalyzer::GetBestTokensTopN(co
                 lookup_left_chars -= next_one_utf8.size();
             }
             auto dp_f = [&dp_vec, i, j, original_sv = std::string_view{current_utf8_ptr, growing_key.size()}](const i32 key_f, const u32 add_l) {
-#ifdef DIVIDE_F_BY_N
-                const i32 key_score = key_f + add_l;
-                for (auto &target_dp = dp_vec[i + j]; const auto &[tn, ss, v] : dp_vec[i].candidates) {
-                    target_dp.AddBestTokenCandidateTopN(tn + 1, ss + key_score, v, original_sv);
+                for (auto &target_dp = dp_vec[i + j]; const auto &c : dp_vec[i].candidates) {
+                    target_dp.AddBestTokenCandidateTopN(c.update(original_sv, key_f, add_l));
                 }
-#else
-                auto get_add_n_l = [add_l](Pair<u32, u32> old_n_l) {
-                    ++old_n_l.first;
-                    old_n_l.second += add_l;
-                    return old_n_l;
-                };
-                for (auto &target_dp = dp_vec[i + j]; const auto &[old_n_l, old_f, old_v] : dp_vec[i].candidates) {
-                    target_dp.AddBestTokenCandidateTopN(get_add_n_l(old_n_l), old_f + key_f, old_v, original_sv);
-                }
-#endif
             };
             if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size());
                 traverse_result >= 0) {
@@ -964,27 +938,46 @@
         current_utf8_ptr += forward_cnt;
         current_left_chars -= forward_cnt;
     }
-    Vector<Pair<Vector<std::string_view>, double>> result;
-    result.reserve(n);
-    constexpr i64 B = 30;
-#ifdef DIVIDE_F_BY_N
-    for (auto &[token_num, score_sum, tokens] : dp_vec.back().candidates) {
-        auto new_pair = std::make_pair(std::move(tokens), (static_cast<double>(B + score_sum) / token_num));
-#else
-    for (auto &[N_L, F, tokens] : dp_vec.back().candidates) {
-        auto new_pair = std::make_pair(std::move(tokens), (F + (static_cast<double>(B + N_L.second) / N_L.first)));
-#endif
-        if (result.size() < n) {
-            result.push_back(std::move(new_pair));
+    Vector<Pair<const TokensList *, double>> mid_result;
+    mid_result.reserve(n);
+    for (const auto &c : dp_vec.back().candidates) {
+        const auto new_pair = std::make_pair(&(c.tl), c.score());
+        if (mid_result.size() < n) {
+            mid_result.push_back(new_pair);
         } else {
-            assert(result.size() == n);
-            if (new_pair.second > result.back().second) {
-                result.pop_back();
-                const auto insert_pos =
-                    std::lower_bound(result.begin(), result.end(), new_pair, [](const auto &a, const auto &b) { return a.second > b.second; });
-                result.insert(insert_pos, std::move(new_pair));
+            assert(mid_result.size() == n);
+            if (new_pair.second > mid_result.back().second) {
+                mid_result.pop_back();
+                const auto insert_pos = std::lower_bound(mid_result.begin(), mid_result.end(), new_pair, [](const auto &a, const auto &b) {
+                    return a.second > b.second;
+                });
+                mid_result.insert(insert_pos, new_pair);
             }
         }
     }
+    class HelperFunc {
+        u32 cnt = 0;
+        Vector<std::string_view> result{};
+        void GetTokensInner(const TokensList *tl) {
+            if (!tl->prev) {
+                result.reserve(cnt);
+                return;
+            }
+            ++cnt;
+            GetTokensInner(tl->prev);
+            result.push_back(tl->token);
+        }
+
+    public:
+        Vector<std::string_view> GetTokens(const TokensList *tl) {
+            GetTokensInner(tl);
+            return std::move(result);
+        }
+    };
+    Vector<Pair<Vector<std::string_view>, double>> result;
+    result.reserve(mid_result.size());
+    for (const auto [tl, score] : mid_result) {
+        result.emplace_back(HelperFunc{}.GetTokens(tl), score);
+    }
     return result;
 }
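The net effect of the diff: a DP candidate no longer owns its own `Vector<std::string_view>` copy of the tokens. Each `BestTokenCandidate` instead carries a `TokensList` node whose `prev` pointer links back through the candidate it was extended from, the two `#ifdef DIVIDE_F_BY_N` scoring variants are folded behind the `k()`/`v()`/`score()` helpers plus an `update()` factory, and full token vectors are rebuilt (by the local `HelperFunc`) only for the top-n candidates that survive to the end. The snippet below is not taken from the PR; it is a minimal, self-contained sketch of the same parent-pointer idea, using hypothetical names (`Chain`, `CollectTokens`) and plain `std::vector` in place of the project's `Vector` alias.

#include <cstddef>
#include <iostream>
#include <string_view>
#include <vector>

// Hypothetical analogue of TokensList: each node appends one token and points
// back at the chain it extends; an empty node acts as the sentinel root.
struct Chain {
    const Chain *prev = nullptr;
    std::string_view token = {};
};

// Rebuild the token sequence once, at the end, similar in spirit to
// HelperFunc::GetTokens: count the links so the vector can be reserved once,
// then emit the tokens oldest-first via recursion.
std::vector<std::string_view> CollectTokens(const Chain *node) {
    std::size_t cnt = 0;
    for (const Chain *p = node; p->prev != nullptr; p = p->prev) {
        ++cnt;
    }
    std::vector<std::string_view> out;
    out.reserve(cnt);
    auto emit = [&out](auto &&self, const Chain *p) -> void {
        if (p->prev == nullptr) {
            return; // the sentinel root carries no token
        }
        self(self, p->prev);
        out.push_back(p->token);
    };
    emit(emit, node);
    return out;
}

int main() {
    Chain root;              // like BestTokenCandidate::tl in a fresh candidate
    Chain a{&root, "rag"};   // each extension shares its prefix via `prev`
    Chain b{&a, "analyzer"}; // instead of copying a token vector per candidate
    Chain c{&b, "update"};
    for (const std::string_view tok : CollectTokens(&c)) {
        std::cout << tok << ' ';
    }
    std::cout << '\n';       // prints: rag analyzer update
    return 0;
}

Under that scheme, extending a candidate costs a pointer plus a few counters rather than a copy of a growing vector of string views, and the vectors that are materialized at the end are only those for the n results actually returned.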