Skip to content

Commit

Permalink
Fix rag image strategies (#2519)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Oct 7, 2024
1 parent de51bac commit 7595a2d
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 52 deletions.
38 changes: 24 additions & 14 deletions nucliadb/src/nucliadb/search/search/chat/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,47 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import base64
from typing import Optional

from nucliadb.common.ids import ParagraphId
from nucliadb.search import SERVICE_NAME
from nucliadb_models.search import Image
from nucliadb_utils.utilities import get_storage


async def get_page_image(kbid: str, paragraph_id: str, page: int) -> Image:
async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int) -> Optional[Image]:
storage = await get_storage(service_name=SERVICE_NAME)

rid, field_type_letter, field_id, _ = paragraph_id.split("/")[:4]

sf = storage.file_extracted(
kbid, rid, field_type_letter, field_id, f"generated/extracted_images_{page}.png"
kbid=kbid,
uuid=paragraph_id.rid,
field_type=paragraph_id.field_id.type,
field=paragraph_id.field_id.key,
key=f"generated/extracted_images_{page_number}.png",
)
image_bytes = (await sf.storage.downloadbytes(sf.bucket, sf.key)).read()
if not image_bytes:
return None
image = Image(
b64encoded=base64.b64encode((await sf.storage.downloadbytes(sf.bucket, sf.key)).read()).decode(),
b64encoded=base64.b64encode(image_bytes).decode(),
content_type="image/png",
)

return image


async def get_paragraph_image(kbid: str, paragraph_id: str, reference: str) -> Image:
async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: str) -> Optional[Image]:
storage = await get_storage(service_name=SERVICE_NAME)

rid, field_type_letter, field_id, _ = paragraph_id.split("/")[:4]

sf = storage.file_extracted(kbid, rid, field_type_letter, field_id, f"generated/{reference}")
sf = storage.file_extracted(
kbid=kbid,
uuid=paragraph_id.rid,
field_type=paragraph_id.field_id.type,
field=paragraph_id.field_id.key,
key=f"generated/{reference}",
)
image_bytes = (await sf.storage.downloadbytes(sf.bucket, sf.key)).read()
if not image_bytes:
return None
image = Image(
b64encoded=base64.b64encode((await sf.storage.downloadbytes(sf.bucket, sf.key)).read()).decode(),
b64encoded=base64.b64encode(image_bytes).decode(),
content_type="image/png",
)

return image
108 changes: 74 additions & 34 deletions nucliadb/src/nucliadb/search/search/chat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nucliadb.ingest.fields.conversation import Conversation
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
from nucliadb.ingest.orm.resource import FIELD_TYPE_STR_TO_PB
from nucliadb.search import logger
from nucliadb.search.search import cache
from nucliadb.search.search.chat.images import get_page_image, get_paragraph_image
from nucliadb.search.search.paragraphs import get_paragraph_text
Expand Down Expand Up @@ -319,11 +320,11 @@ async def extend_prompt_context_with_metadata(
for text_block_id in context.text_block_ids():
try:
text_block_ids.append(parse_text_block_id(text_block_id))
except ValueError:
except ValueError: # pragma: no cover
# Some text block ids are not paragraphs nor fields, so they are skipped
# (e.g. USER_CONTEXT_0, when the user provides extra context)
continue
if len(text_block_ids) == 0:
if len(text_block_ids) == 0: # pragma: no cover
return

if MetadataExtensionType.ORIGIN in strategy.types:
Expand Down Expand Up @@ -385,7 +386,7 @@ async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tu
for fc in pb_basic.computedmetadata.field_classifications:
if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
for classif in fc.classifications:
if classif.cancelled_by_user:
if classif.cancelled_by_user: # pragma: no cover
continue
labels.add((classif.labelset, classif.label))
return _id, list(labels)
Expand Down Expand Up @@ -482,15 +483,15 @@ async def field_extension_prompt_context(
try:
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
extend_field_ids.append(fid)
except ValueError:
except ValueError: # pragma: no cover
# Invalid field id, skiping
continue

tasks = [get_field_extracted_text(kbid, fid) for fid in extend_field_ids]
field_extracted_texts = await run_concurrently(tasks)

for result in field_extracted_texts:
if result is None:
if result is None: # pragma: no cover
continue
field, extracted_text = result
# First off, remove the text block ids from paragraphs that belong to
Expand Down Expand Up @@ -572,13 +573,13 @@ async def get_field_paragraphs_list(
Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
"""
resource = await cache.get_resource(kbid, field.rid)
if resource is None:
if resource is None: # pragma: no cover
return
field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
force=True
)
if field_metadata is None:
if field_metadata is None: # pragma: no cover
return
for paragraph in field_metadata.metadata.paragraphs:
paragraphs.append(
Expand Down Expand Up @@ -630,7 +631,7 @@ async def neighbouring_paragraphs_prompt_context(
)
)
)
if not paragraph_ops:
if not paragraph_ops: # pragma: no cover
return

results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
Expand Down Expand Up @@ -782,33 +783,64 @@ async def build(
return context, context_order, context_images

async def _build_context_images(self, context: CappedPromptContext) -> None:
page_count = 5
gather_pages = False
gather_tables = False
if self.image_strategies is not None:
for strategy in self.image_strategies:
if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
strategy = cast(PageImageStrategy, strategy)
gather_pages = True
if strategy.count is not None and strategy.count > 0:
page_count = strategy.count
elif strategy.name == ImageRagStrategyName.TABLES:
strategy = cast(TableImageStrategy, strategy)
gather_tables = True
if self.image_strategies is None or len(self.image_strategies) == 0:
# Nothing to do
return

page_image_strategy: Optional[PageImageStrategy] = None
max_page_images = 5
table_image_strategy: Optional[TableImageStrategy] = None
for strategy in self.image_strategies:
if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
page_image_strategy = cast(PageImageStrategy, strategy)
if page_image_strategy.count is not None:
max_page_images = page_image_strategy.count
elif strategy.name == ImageRagStrategyName.TABLES:
table_image_strategy = cast(TableImageStrategy, strategy)

page_images_added = 0
for paragraph in self.ordered_paragraphs:
if paragraph.page_with_visual and paragraph.position:
if gather_pages and paragraph.position.page_number and len(context.images) < page_count:
field = "/".join(paragraph.id.split("/")[:3])
page = paragraph.position.page_number
page_id = f"{field}/{page}"
if page_id not in context.images:
context.images[page_id] = await get_page_image(self.kbid, paragraph.id, page)
# Only send tables if enabled by strategy, by default, send paragraph images
send_images = (gather_tables and paragraph.is_a_table) or not paragraph.is_a_table
if send_images and paragraph.reference and paragraph.reference != "":
image = paragraph.reference
context.images[paragraph.id] = await get_paragraph_image(self.kbid, paragraph.id, image)
pid = ParagraphId.from_string(paragraph.id)
paragraph_page_number = get_paragraph_page_number(paragraph)
if (
page_image_strategy is not None
and page_images_added < max_page_images
and paragraph_page_number is not None
):
# page_image_id: rid/f/myfield/0
page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
if page_image_id not in context.images:
image = await get_page_image(self.kbid, pid, paragraph_page_number)
if image is not None:
context.images[page_image_id] = image
page_images_added += 1
else:
logger.warning(
f"Could not retrieve image for paragraph from storage",
extra={
"kbid": self.kbid,
"paragraph": pid.full(),
"page_number": paragraph_page_number,
},
)
if (
table_image_strategy is not None
and paragraph.is_a_table
and paragraph.reference is not None
and paragraph.reference != ""
):
pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
if pimage is not None:
context.images[paragraph.id] = pimage
else:
logger.warning(
f"Could not retrieve table image for paragraph from storage",
extra={
"kbid": self.kbid,
"paragraph": pid.full(),
"reference": paragraph.reference,
},
)

async def _build_context(self, context: CappedPromptContext) -> None:
if self.strategies is None or len(self.strategies) == 0:
Expand All @@ -830,7 +862,7 @@ async def _build_context(self, context: CappedPromptContext) -> None:
field_extension = cast(FieldExtensionStrategy, strategy)
elif strategy.name == RagStrategyName.FULL_RESOURCE:
full_resource = cast(FullResourceStrategy, strategy)
if self.resource:
if self.resource: # pragma: no cover
# When the retrieval is scoped to a specific resource
# the full resource strategy only includes that resource
full_resource.count = 1
Expand Down Expand Up @@ -875,6 +907,14 @@ async def _build_context(self, context: CappedPromptContext) -> None:
await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)


def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
if not paragraph.page_with_visual:
return None
if paragraph.position is None:
return None
return paragraph.position.page_number


@dataclass
class ExtraCharsParagraph:
title: str
Expand Down
24 changes: 24 additions & 0 deletions nucliadb/tests/nucliadb/integration/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,27 @@ async def test_ask_fails_with_answer_json_schema_too_big(
resp.json()["detail"]
== "Answer JSON schema with too many properties generated too many prequeries"
)


async def test_rag_image_rag_strategies(
nucliadb_reader: AsyncClient,
knowledgebox: str,
resources: list[str],
):
resp = await nucliadb_reader.post(
f"/kb/{knowledgebox}/ask",
headers={"X-Synchronous": "True"},
json={
"query": "title",
"rag_image_strategies": [
{
"name": "page_image",
"count": 2,
},
{
"name": "tables",
},
],
},
)
assert resp.status_code == 200, resp.text
35 changes: 33 additions & 2 deletions nucliadb/tests/search/unit/search/test_chat_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
FindRequest,
FindResource,
HierarchyResourceStrategy,
Image,
KnowledgeboxFindResults,
MetadataExtensionStrategy,
MetadataExtensionType,
MinScore,
PageImageStrategy,
PreQuery,
TableImageStrategy,
)
from nucliadb_protos import resources_pb2 as rpb2

Expand Down Expand Up @@ -214,10 +217,10 @@ def find_results():
facets={},
resources={
"resource1": _create_find_result(
"resource1/a/title", "Resource 1", SCORE_TYPE.BOTH, order=1
"resource1/a/title/0-10", "Resource 1", SCORE_TYPE.BOTH, order=1
),
"resource2": _create_find_result(
"resource2/a/title", "Resource 2", SCORE_TYPE.VECTOR, order=2
"resource2/a/title/0-10", "Resource 2", SCORE_TYPE.VECTOR, order=2
),
},
min_score=MinScore(semantic=-1),
Expand Down Expand Up @@ -529,3 +532,31 @@ def test_get_ordered_paragraphs():
assert ordered_paragraphs[3].id == "prequery-2-result/f/f1/10-20"
assert ordered_paragraphs[4].id == "prequery-1-result/f/f1/0-10"
assert ordered_paragraphs[5].id == "prequery-1-result/f/f1/10-20"


@pytest.mark.asyncio
async def test_prompt_context_image_context_builder(
find_results: KnowledgeboxFindResults,
):
builder = chat_prompt.PromptContextBuilder(
kbid="kbid",
main_results=find_results,
user_context=["Carrots are orange"],
image_strategies=[PageImageStrategy(count=10), TableImageStrategy()],
)
module = "nucliadb.search.search.chat.prompt"
with (
mock.patch(f"{module}.get_paragraph_page_number", return_value=1),
mock.patch(
f"{module}.get_page_image",
return_value=Image(b64encoded="page_image_data", content_type="image/png"),
),
mock.patch(
f"{module}.get_paragraph_image",
return_value=Image(b64encoded="table_image_data", content_type="image/png"),
),
):
context = chat_prompt.CappedPromptContext(max_size=int(1e6))
await builder._build_context_images(context)
assert len(context.output) == 0
assert len(context.images) == 2
4 changes: 2 additions & 2 deletions nucliadb_models/src/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,11 +1131,11 @@ class PreQueriesStrategy(RagStrategy):


class TableImageStrategy(ImageRagStrategy):
name: Literal["tables"]
name: Literal["tables"] = "tables"


class PageImageStrategy(ImageRagStrategy):
name: Literal["page_image"]
name: Literal["page_image"] = "page_image"
count: Optional[int] = Field(
default=None,
title="Count",
Expand Down

0 comments on commit 7595a2d

Please sign in to comment.