Skip to content

Commit

Permalink
py lint (#12102)
Browse files Browse the repository at this point in the history
Signed-off-by: -LAN- <[email protected]>
Co-authored-by: -LAN- <[email protected]>
  • Loading branch information
JohnJyong and laipz8200 authored Dec 25, 2024
1 parent bb35818 commit 84ac004
Show file tree
Hide file tree
Showing 20 changed files with 262 additions and 208 deletions.
2 changes: 1 addition & 1 deletion api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green"))

# run db migration
import flask_migrate
import flask_migrate # type: ignore

flask_migrate.upgrade()

Expand Down
7 changes: 4 additions & 3 deletions api/controllers/console/datasets/datasets_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,14 +413,15 @@ def get(self, dataset_id, document_id):
indexing_runner = IndexingRunner()

try:
response = indexing_runner.indexing_estimate(
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
[extract_setting],
data_process_rule_dict,
document.doc_form,
"English",
dataset_id,
)
return estimate_response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
Expand All @@ -431,7 +432,7 @@ def get(self, dataset_id, document_id):
except Exception as e:
raise IndexingEstimateError(str(e))

return response.model_dump(), 200
return response, 200


class DocumentBatchIndexingEstimateApi(DocumentResource):
Expand Down Expand Up @@ -521,6 +522,7 @@ def get(self, dataset_id, batch):
"English",
dataset_id,
)
return response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
Expand All @@ -530,7 +532,6 @@ def get(self, dataset_id, batch):
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
return response.model_dump(), 200


class DocumentBatchIndexingStatusApi(DocumentResource):
Expand Down
22 changes: 14 additions & 8 deletions api/controllers/service_api/dataset/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService


Expand Down Expand Up @@ -67,13 +68,14 @@ def post(self, tenant_id, dataset_id):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
Expand Down Expand Up @@ -122,12 +124,13 @@ def post(self, tenant_id, dataset_id, document_id):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
Expand Down Expand Up @@ -186,12 +189,13 @@ def post(self, tenant_id, dataset_id):
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source
# validate args
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
Expand Down Expand Up @@ -245,12 +249,14 @@ def post(self, tenant_id, dataset_id, document_id):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)

knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
Expand Down
14 changes: 7 additions & 7 deletions api/core/indexing_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def indexing_estimate(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
preview_texts = [] # type: ignore

total_segments = 0
index_type = doc_form
Expand All @@ -300,13 +300,13 @@ def indexing_estimate(
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_texts.append(preview_detail)

# delete image files and related db records
Expand All @@ -325,7 +325,7 @@ def indexing_estimate(

if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore

def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
Expand Down Expand Up @@ -454,7 +454,7 @@ def _get_splitter(
embedding_model_instance=embedding_model_instance,
)

return character_splitter
return character_splitter # type: ignore

def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
Expand Down Expand Up @@ -535,7 +535,7 @@ def _load(
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()

Expand Down
129 changes: 65 additions & 64 deletions api/core/rag/datasource/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,78 +258,79 @@ def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegme
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
document_id = document.metadata.get("document_id")
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata.get("doc_id")
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
continue
else:
continue
else:
index_node_id = document.metadata["doc_id"]
index_node_id = document.metadata["doc_id"]

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
.first()
)

if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}

records.append(record)
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
Expand Down
37 changes: 19 additions & 18 deletions api/core/rag/docstore/dataset_docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,27 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
db.session.add(segment_document)
db.session.flush()
if save_child:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
if doc.children:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else:
segment_document.content = doc.page_content
if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "")
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
if save_child and doc.children:
Expand All @@ -160,8 +161,8 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/extractor/excel_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, cast

import pandas as pd
from openpyxl import load_workbook
from openpyxl import load_workbook # type: ignore

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/index_processor/index_processor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def _get_splitter(
embedding_model_instance=embedding_model_instance,
)

return character_splitter
return character_splitter # type: ignore
Loading

0 comments on commit 84ac004

Please sign in to comment.