From 2f70567be83dff950ed2ac70b71eac11565d177d Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Fri, 27 Dec 2024 11:24:06 +0800
Subject: [PATCH] fix: update keyword extraction to remove optional parameter and improve type casting

Signed-off-by: -LAN-
---
 .../keyword/jieba/jieba_keyword_table_handler.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 8b17e8dc0a3762..a6214d955b1ddd 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,5 +1,5 @@
 import re
-from typing import Optional
+from typing import Optional, cast
 
 
 class JiebaKeywordTableHandler:
@@ -8,18 +8,20 @@ def __init__(self):
 
         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 
-        jieba.analyse.default_tfidf.stop_words = STOPWORDS
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
-        import jieba  # type: ignore
+        import jieba.analyse  # type: ignore
 
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
         )
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        keywords = cast(list[str], keywords)
 
-        return set(self._expand_tokens_with_subtokens(keywords))
+        return set(self._expand_tokens_with_subtokens(set(keywords)))
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""