Improve performance of RAG tokenizer (#2177)
### What problem does this PR solve?

- Split long text into shorter pieces before tokenizing; further improvements
  are required for smarter splitting. A rough sketch of the idea follows below.
- Fix a deadlock in memory_indexer during offline index building.

Issue link: #2159
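
A minimal sketch of the splitting idea (hypothetical helper names, byte-oriented for brevity; the actual implementation is the UTF-8-aware `SplitLongText` in `rag_analyzer.cpp` below): cut roughly every `MAX_SENTENCE_LEN` characters, then probe a short window around the nominal cut and shift the cut to a point where two independent segmentations agree, so no word is split in half.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Placeholder for the real probe: the analyzer tokenizes the window forward
// and backward with maximum matching and returns how long a prefix the two
// segmentations agree on. Here we simply fall back to the nearest space.
static std::size_t AgreeingPrefixLen(const std::string &window) {
    std::size_t space = window.find(' ');
    return space == std::string::npos ? 0 : space;
}

// Sketch only (assumes max_len > 0): split `text` into pieces of roughly
// `max_len` bytes, nudging each cut forward to an agreed boundary.
std::vector<std::string> SplitRoughly(const std::string &text, std::size_t max_len) {
    std::vector<std::string> pieces;
    std::size_t start = 0;
    while (start < text.size()) {
        std::size_t cut = std::min(start + max_len, text.size());
        if (cut < text.size()) {
            cut += AgreeingPrefixLen(text.substr(cut, 10)); // 10-char probe window
        }
        pieces.push_back(text.substr(start, cut - start));
        start = cut;
    }
    return pieces;
}
```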

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Performance Improvement
yingfeng authored Nov 5, 2024
1 parent 98091cc commit bcde525
Showing 3 changed files with 115 additions and 50 deletions.
150 changes: 106 additions & 44 deletions src/common/analyzer/rag_analyzer.cpp
@@ -53,6 +53,8 @@ static const String REGEX_SPLIT_CHAR = R"#(([ ,\.<>/?;'\[\]\`!@#$%^&*\(\)\{\}\|_+=
 static const String NLTK_TOKENIZE_PATTERN =
     R"((?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|(?=[^\(\"\`{\[:;&\#\*@\)}\]\-,])\S+?(?=\s|$|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|,(?=$|\s|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)))|\S)";
 
+static constexpr std::size_t MAX_SENTENCE_LEN = 100;
+
 static inline i32 Encode(i32 freq, i32 idx) {
     u32 encoded_value = 0;
     if (freq < 0) {
@@ -598,6 +600,101 @@ void RAGAnalyzer::EnglishNormalize(const Vector<String> &tokens, Vector<String>
     }
 }
 
+void RAGAnalyzer::TokenizeInner(Vector<String> &res, const String &L) {
+    auto [tks, s] = MaxForward(L);
+    auto [tks1, s1] = MaxBackward(L);
+
+    Vector<int> diff(std::max(tks.size(), tks1.size()), 0);
+    for (std::size_t i = 0; i < std::min(tks.size(), tks1.size()); ++i) {
+        if (tks[i] != tks1[i]) {
+            diff[i] = 1;
+        }
+    }
+
+    if (s1 > s) {
+        tks = tks1;
+    }
+
+    std::size_t i = 0;
+    while (i < tks.size()) {
+        std::size_t s = i;
+        while (s < tks.size() && diff[s] == 0) {
+            s++;
+        }
+        if (s == tks.size()) {
+            res.push_back(Join(tks, i, tks.size()));
+            break;
+        }
+        if (s > i) {
+            res.push_back(Join(tks, i, s));
+        }
+
+        std::size_t e = s;
+        while (e < tks.size() && e - s < 5 && diff[e] == 1) {
+            e++;
+        }
+
+        Vector<Pair<String, int>> pre_tokens;
+        Vector<Vector<Pair<String, int>>> token_list;
+        Vector<String> best_tokens;
+        double max_score = -100.0F;
+        DFS(Join(tks, s, e < tks.size() ? e + 1 : e, ""), 0, pre_tokens, token_list, best_tokens, max_score, false);
+        // Vector<Pair<Vector<String>, double>> sorted_tokens;
+        // SortTokens(token_list, sorted_tokens);
+        // res.push_back(Join(sorted_tokens[0].first, 0));
+        res.push_back(Join(best_tokens, 0));
+        i = e + 1;
+    }
+}
+
+void RAGAnalyzer::SplitLongText(const String &L, u32 length, Vector<String> &sublines) {
+    u32 slice_count = length / MAX_SENTENCE_LEN + 1;
+    sublines.reserve(slice_count);
+    std::size_t last_sentence_start = 0;
+    std::size_t next_sentence_start = 0;
+    for (unsigned i = 0; i < slice_count; ++i) {
+        next_sentence_start = MAX_SENTENCE_LEN * (i + 1) - 5;
+        if (next_sentence_start + 5 < length) {
+            std::size_t sentence_length = MAX_SENTENCE_LEN * (i + 1) + 5 > length ? length - next_sentence_start : 10;
+            String substr = UTF8Substr(L, next_sentence_start, sentence_length);
+            auto [tks, s] = MaxForward(substr);
+            auto [tks1, s1] = MaxBackward(substr);
+            Vector<int> diff(std::max(tks.size(), tks1.size()), 0);
+            for (std::size_t j = 0; j < std::min(tks.size(), tks1.size()); ++j) {
+                if (tks[j] != tks1[j]) {
+                    diff[j] = 1;
+                }
+            }
+
+            if (s1 > s) {
+                tks = tks1;
+            }
+            std::size_t start = 0;
+            std::size_t forward_same_len = 0;
+            while (start < tks.size() && diff[start] == 0) {
+                forward_same_len += UTF8Length(tks[start]);
+                start++;
+            }
+            if (forward_same_len == 0) {
+                i64 end = static_cast<i64>(tks.size()) - 1; // signed: with std::size_t, `end >= 0` below is always true
+                std::size_t backward_same_len = 0;
+                while (end >= 0 && diff[end] == 0) {
+                    backward_same_len += UTF8Length(tks[end]);
+                    end--;
+                }
+                next_sentence_start += sentence_length - backward_same_len;
+            } else
+                next_sentence_start += forward_same_len;
+        } else
+            next_sentence_start = length;
+        if (next_sentence_start == last_sentence_start)
+            continue;
+        String str = UTF8Substr(L, last_sentence_start, next_sentence_start - last_sentence_start);
+        sublines.push_back(str);
+        last_sentence_start = next_sentence_start;
+    }
+}
+
 String RAGAnalyzer::Tokenize(const String &line) {
     String str1 = StrQ2B(line);
     String strline;
@@ -630,54 +727,19 @@ String RAGAnalyzer::Tokenize(const String &line) {
     Vector<String> arr;
     Split(strline, regex_split_pattern_, arr, true);
     for (const auto &L : arr) {
-        if (UTF8Length(L) < 2 || RE2::PartialMatch(L, pattern2_) || RE2::PartialMatch(L, pattern3_)) { //[a-z\\.-]+$ [0-9\\.-]+$
+        auto length = UTF8Length(L);
+        if (length < 2 || RE2::PartialMatch(L, pattern2_) || RE2::PartialMatch(L, pattern3_)) { //[a-z\\.-]+$ [0-9\\.-]+$
             res.push_back(L);
             continue;
         }
-        auto [tks, s] = MaxForward(L);
-        auto [tks1, s1] = MaxBackward(L);
-
-        Vector<int> diff(std::max(tks.size(), tks1.size()), 0);
-        for (std::size_t i = 0; i < std::min(tks.size(), tks1.size()); ++i) {
-            if (tks[i] != tks1[i]) {
-                diff[i] = 1;
+        if (length > MAX_SENTENCE_LEN) {
+            Vector<String> sublines;
+            SplitLongText(L, length, sublines);
+            for (auto &l : sublines) {
+                TokenizeInner(res, l);
             }
-        }
-
-        if (s1 > s) {
-            tks = tks1;
-        }
-
-        std::size_t i = 0;
-        while (i < tks.size()) {
-            std::size_t s = i;
-            while (s < tks.size() && diff[s] == 0) {
-                s++;
-            }
-            if (s == tks.size()) {
-                res.push_back(Join(tks, i, tks.size()));
-                break;
-            }
-            if (s > i) {
-                res.push_back(Join(tks, i, s));
-            }
-
-            std::size_t e = s;
-            while (e < tks.size() && e - s < 5 && diff[e] == 1) {
-                e++;
-            }
-
-            Vector<Pair<String, int>> pre_tokens;
-            Vector<Vector<Pair<String, int>>> token_list;
-            Vector<String> best_tokens;
-            double max_score = 0.0F;
-            DFS(Join(tks, s, e < tks.size() ? e + 1 : e, ""), 0, pre_tokens, token_list, best_tokens, max_score, false);
-            // Vector<Pair<Vector<String>, double>> sorted_tokens;
-            // SortTokens(token_list, sorted_tokens);
-            // res.push_back(Join(sorted_tokens[0].first, 0));
-            res.push_back(Join(best_tokens, 0));
-            i = e + 1;
-        }
+        } else
+            TokenizeInner(res, L);
     }
 
     Vector<String> normalize_res;
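
For orientation, `TokenizeInner` above segments the line twice, with forward and backward maximum matching, keeps whichever scores higher, emits runs where the two segmentations agree verbatim, and re-segments only the disagreement spans (capped at five tokens) with the exhaustive `DFS` scorer. A minimal sketch of the agreement-marking step, using plain standard-library types rather than the analyzer's `Vector`/`String` aliases:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Mark 1 at each position where two candidate segmentations disagree;
// positions beyond the shorter segmentation stay 0, as in the PR's loop.
std::vector<int> MarkDisagreements(const std::vector<std::string> &forward,
                                   const std::vector<std::string> &backward) {
    std::vector<int> diff(std::max(forward.size(), backward.size()), 0);
    for (std::size_t i = 0; i < std::min(forward.size(), backward.size()); ++i) {
        if (forward[i] != backward[i]) {
            diff[i] = 1;
        }
    }
    return diff;
}
```

Agreeing stretches cost nothing extra to emit; only the marked stretches pay for the DFS re-segmentation, which keeps the tokenizer fast on long inputs while preserving quality where the two heuristics disagree.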
5 changes: 4 additions & 1 deletion src/common/analyzer/rag_analyzer.cppm
@@ -54,7 +54,6 @@ public:
 protected:
     int AnalyzeImpl(const Term &input, void *data, HookType func) override;
 
-protected:
 private:
     static constexpr float DENOMINATOR = 1000000;
 
@@ -82,6 +81,10 @@ private:
              double &max_score,
              bool memo_all);
 
+    void TokenizeInner(Vector<String> &res, const String &L);
+
+    void SplitLongText(const String &L, u32 length, Vector<String> &sublines);
+
     String Merge(const String &tokens);
 
     void EnglishNormalize(const Vector<String> &tokens, Vector<String> &res);
10 changes: 5 additions & 5 deletions src/storage/invertedindex/memory_indexer.cpp
@@ -431,15 +431,15 @@ void MemoryIndexer::TupleListToIndexFile(UniquePtr<SortMergerTermTuple<TermTuple
         doc_pos_list_size = temp_term_tuple->Size();
         term_length = temp_term_tuple->term_.size();
 
-        if (term_length >= MAX_TUPLE_LENGTH) {
-            continue;
-        }
-
         if (count < temp_term_tuple->Size()) {
             UnrecoverableError("Unexpected error in TupleListToIndexFile");
         }
 
         count -= temp_term_tuple->Size();
 
+        if (term_length >= MAX_TUPLE_LENGTH) {
+            continue;
+        }
+
         std::string_view term = std::string_view(temp_term_tuple->term_);
 
         if (term != last_term) {
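
This reordering is the deadlock fix from the description: the old code hit `continue` for over-long terms before the `count -= temp_term_tuple->Size()` bookkeeping ran, so `count` never drained and the offline merge loop could spin forever. A schematic of the corrected shape (hypothetical types and names, not the indexer's real API):

```cpp
#include <cstddef>
#include <string>
#include <vector>

constexpr std::size_t kMaxTermLength = 1024; // assumed cap, for illustration

struct TermTuple {
    std::string term;
    std::size_t postings = 0; // how many postings this tuple accounts for
};

// Drain a fixed budget of postings. The budget must be decremented before
// any early `continue`, or the exit condition is never reached -- the hang
// this commit fixes.
void DrainTuples(const std::vector<TermTuple> &tuples, std::size_t budget) {
    for (const auto &tuple : tuples) {
        if (budget < tuple.postings) {
            break; // analogous to the UnrecoverableError guard above
        }
        budget -= tuple.postings;                 // account first...
        if (tuple.term.size() >= kMaxTermLength) {
            continue;                             // ...then skip safely
        }
        // ... write the tuple's postings to the index file here ...
    }
}
```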
