diff --git a/nucliadb/src/nucliadb/ingest/orm/brain.py b/nucliadb/src/nucliadb/ingest/orm/brain.py
index 4ffb4bd0e1..0b43c574b4 100644
--- a/nucliadb/src/nucliadb/ingest/orm/brain.py
+++ b/nucliadb/src/nucliadb/ingest/orm/brain.py
@@ -498,24 +498,54 @@ def process_field_metadata(
                     to=relation_node_label,
                 )
             )
-
-        for klass_entity, _ in metadata.positions.items():
-            labels["e"].add(klass_entity)
-            entity_array = klass_entity.split("/")
-            if len(entity_array) == 1:
-                raise AttributeError(f"Entity should be with type {klass_entity}")
-            elif len(entity_array) > 1:
-                klass = entity_array[0]
-                entity = "/".join(entity_array[1:])
-                relation_node_entity = RelationNode(
-                    value=entity, ntype=RelationNode.NodeType.ENTITY, subtype=klass
-                )
-                rel = Relation(
-                    relation=Relation.ENTITY,
-                    source=relation_node_document,
-                    to=relation_node_entity,
-                )
-                self.brain.relations.append(rel)
+        # Data Augmentation + Processor entities
+        use_legacy_entities = True
+        for data_augmentation_task_id, entities in metadata.entities.items():
+            # If we received the entities from the processor here, we don't want to use the legacy entities
+            # TODO: Remove this when processor doesn't use this anymore
+            if data_augmentation_task_id == "processor":
+                use_legacy_entities = False
+
+            for ent in entities.entities:
+                entity_text = ent.text
+                entity_label = ent.label
+                # Seems like we don't care about where the entity is in the text
+                # entity_positions = ent.positions
+                labels["e"].add(
+                    f"{entity_label}/{entity_text}"
+                )  # Add data_augmentation_task_id as a prefix?
+                relation_node_entity = RelationNode(
+                    value=entity_text,
+                    ntype=RelationNode.NodeType.ENTITY,
+                    subtype=entity_label,
+                )
+                rel = Relation(
+                    relation=Relation.ENTITY,
+                    source=relation_node_document,
+                    to=relation_node_entity,
+                )
+                self.brain.relations.append(rel)
+
+        # Legacy processor entities
+        # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+        if use_legacy_entities:
+            for klass_entity, _ in metadata.positions.items():
+                labels["e"].add(klass_entity)
+                entity_array = klass_entity.split("/")
+                if len(entity_array) == 1:
+                    raise AttributeError(f"Entity should be with type {klass_entity}")
+                elif len(entity_array) > 1:
+                    klass = entity_array[0]
+                    entity = "/".join(entity_array[1:])
+                    relation_node_entity = RelationNode(
+                        value=entity, ntype=RelationNode.NodeType.ENTITY, subtype=klass
+                    )
+                    rel = Relation(
+                        relation=Relation.ENTITY,
+                        source=relation_node_document,
+                        to=relation_node_entity,
+                    )
+                    self.brain.relations.append(rel)
 
     def apply_field_labels(
         self,
diff --git a/nucliadb/src/nucliadb/ingest/orm/resource.py b/nucliadb/src/nucliadb/ingest/orm/resource.py
index b9f29f9261..3ca1e6234d 100644
--- a/nucliadb/src/nucliadb/ingest/orm/resource.py
+++ b/nucliadb/src/nucliadb/ingest/orm/resource.py
@@ -23,7 +23,7 @@
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import TYPE_CHECKING, Any, AsyncIterator, Optional, Type
+from typing import TYPE_CHECKING, Any, AsyncIterator, MutableMapping, Optional, Type
 
 from nucliadb.common import datamanagers
 from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG
@@ -890,7 +890,7 @@ async def iterate_sentences(
 
                 entities: dict[str, str] = {}
                 if enabled_metadata.entities:
-                    entities.update(field_metadata.ner)
+                    _update_entities_dict(entities, field_metadata)
 
                 precomputed_vectors = {}
                 if vo is not None:
@@ -996,7 +996,7 @@ async def iterate_paragraphs(
 
                 entities: dict[str, str] = {}
                 if enabled_metadata.entities:
-                    entities.update(field_metadata.ner)
+                    _update_entities_dict(entities, field_metadata)
 
                 if extracted_text is not None:
                     if subfield is not None:
@@ -1075,7 +1075,7 @@ async def iterate_fields(self, enabled_metadata: EnabledMetadata) -> AsyncIterat
 
             if enabled_metadata.entities:
                 metadata.ClearField("entities")
-                metadata.entities.update(splitted_metadata.ner)
+                _update_entities_dict(metadata.entities, splitted_metadata)
 
             pb_field = TrainField()
             pb_field.uuid = self.uuid
@@ -1119,7 +1119,7 @@ async def generate_train_resource(self, enabled_metadata: EnabledMetadata) -> Tr
             metadata.labels.field.extend(splitted_metadata.classifications)
 
             if enabled_metadata.entities:
-                metadata.entities.update(splitted_metadata.ner)
+                _update_entities_dict(metadata.entities, splitted_metadata)
 
         pb_resource = TrainResource()
         pb_resource.uuid = self.uuid
@@ -1254,3 +1254,23 @@ def extract_field_metadata_languages(
     for _, splitted_metadata in field_metadata.metadata.split_metadata.items():
         languages.add(splitted_metadata.language)
     return list(languages)
+
+
+def _update_entities_dict(target_entities_dict: MutableMapping[str, str], field_metadata: FieldMetadata):
+    """
+    Update the entities dict with the entities from the field metadata.
+    Method created to ease the transition from the legacy ner field to the new entities field.
+    """
+    # Data Augmentation + Processor entities
+    # This will overwrite entities detected by more than one data augmentation task
+    # TODO: Change TrainMetadata proto to accept multiple entities with the same text
+    entity_map = {
+        entity.text: entity.label
+        for data_augmentation_task_id, entities_wrapper in field_metadata.entities.items()
+        for entity in entities_wrapper.entities
+    }
+    target_entities_dict.update(entity_map)
+
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+    target_entities_dict.update(field_metadata.ner)
diff --git a/nucliadb/src/nucliadb/search/search/chat/prompt.py b/nucliadb/src/nucliadb/search/search/chat/prompt.py
index 7107a39851..52a35b4784 100644
--- a/nucliadb/src/nucliadb/search/search/chat/prompt.py
+++ b/nucliadb/src/nucliadb/search/search/chat/prompt.py
@@ -413,6 +413,12 @@ async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str,
         field = await resource.get_field(fid.key, fid.pb_type, load=False)
         fcm = await field.get_field_metadata()
         if fcm is not None:
+            # Data Augmentation + Processor entities
+            for data_augmentation_task_id, entities_wrapper in fcm.metadata.entities.items():
+                for entity in entities_wrapper.entities:
+                    ners.setdefault(entity.label, set()).add(entity.text)
+            # Legacy processor entities
+            # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
             for token, family in fcm.metadata.ner.items():
                 ners.setdefault(family, set()).add(token)
     return _id, ners
diff --git a/nucliadb/src/nucliadb/train/generators/token_classifier.py b/nucliadb/src/nucliadb/train/generators/token_classifier.py
index 4c5b039d6f..179742103d 100644
--- a/nucliadb/src/nucliadb/train/generators/token_classifier.py
+++ b/nucliadb/src/nucliadb/train/generators/token_classifier.py
@@ -139,6 +139,36 @@ async def get_field_text(
     field_metadata = await field_obj.get_field_metadata()
     # Check computed definition of entities
     if field_metadata is not None:
+        # Data Augmentation + Processor entities
+        for data_augmentation_task_id, entities in field_metadata.metadata.entities.items():
+            for entity in entities.entities:
+                entity_text = entity.text
+                entity_label = entity.label
+                entity_positions = entity.positions
+                if entity_label in valid_entity_groups:
+                    split_ners[MAIN].setdefault(entity_label, {}).setdefault(entity_text, [])
+                    for position in entity_positions:
+                        split_ners[MAIN][entity_label][entity_text].append(
+                            (position.start, position.end)
+                        )
+
+        for split, split_metadata in field_metadata.split_metadata.items():
+            for data_augmentation_task_id, entities in split_metadata.entities.items():
+                for entity in entities.entities:
+                    entity_text = entity.text
+                    entity_label = entity.label
+                    entity_positions = entity.positions
+                    if entity_label in valid_entity_groups:
+                        split_ners.setdefault(split, {}).setdefault(entity_label, {}).setdefault(
+                            entity_text, []
+                        )
+                        for position in entity_positions:
+                            split_ners[split][entity_label][entity_text].append(
+                                (position.start, position.end)
+                            )
+
+        # Legacy processor entities
+        # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
         for entity_key, positions in field_metadata.metadata.positions.items():
             entities = entity_key.split("/")
             entity_group = entities[0]
diff --git a/nucliadb/tests/ingest/fixtures.py b/nucliadb/tests/ingest/fixtures.py
index 780f5ca3eb..494c4dc870 100644
--- a/nucliadb/tests/ingest/fixtures.py
+++ b/nucliadb/tests/ingest/fixtures.py
@@ -387,9 +387,24 @@ def make_field_metadata(field_id):
     ex1.metadata.metadata.last_extract.FromDatetime(datetime.now())
     ex1.metadata.metadata.last_summary.FromDatetime(datetime.now())
     ex1.metadata.metadata.thumbnail.CopyFrom(THUMBNAIL)
-    ex1.metadata.metadata.positions["ENTITY/document"].entity = "document"
-    ex1.metadata.metadata.positions["ENTITY/document"].position.extend(
-        [rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)]
+    # Data Augmentation + Processor entities
+    ex1.metadata.metadata.entities["processor"].entities.extend(
+        [
+            rpb.FieldEntity(
+                text="document",
+                label="ENTITY",
+                positions=[rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)],
+            ),
+        ]
+    )
+    ex1.metadata.metadata.entities["my-task-id"].entities.extend(
+        [
+            rpb.FieldEntity(
+                text="document",
+                label="NOUN",
+                positions=[rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)],
+            ),
+        ]
     )
 
     return ex1
@@ -511,7 +526,12 @@ def broker_resource(
     fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
-    fcm.metadata.metadata.ner["Ramon"] = "PERSON"
+    fcm.metadata.metadata.entities["processor"].entities.extend(
+        [rpb.FieldEntity(text="Ramon", label="PERSON")]
+    )
+    fcm.metadata.metadata.entities["my-data-augmentation"].entities.extend(
+        [rpb.FieldEntity(text="Ramon", label="CTO")]
+    )
 
     c1 = rpb.Classification()
     c1.label = "label1"
diff --git a/nucliadb/tests/ingest/integration/ingest/test_ingest.py b/nucliadb/tests/ingest/integration/ingest/test_ingest.py
index bf6ec4c936..a5fef32d47 100644
--- a/nucliadb/tests/ingest/integration/ingest/test_ingest.py
+++ b/nucliadb/tests/ingest/integration/ingest/test_ingest.py
@@ -50,6 +50,7 @@
     ExtractedVectorsWrapper,
     FieldComputedMetadata,
     FieldComputedMetadataWrapper,
+    FieldEntity,
     FieldID,
     FieldMetadata,
     FieldQuestionAnswerWrapper,
@@ -166,6 +167,13 @@ async def test_ingest_messages_autocommit(kbid: str, processor):
     fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
+    # Data Augmentation + Processor entities
+    fcm.metadata.metadata.entities["processor"].entities.extend(
+        [FieldEntity(text="Ramon", label="PERSON")]
+    )
+
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
     fcm.metadata.metadata.ner["Ramon"] = "PERSON"
 
     c1 = Classification()
diff --git a/nucliadb/tests/ingest/integration/ingest/test_relations.py b/nucliadb/tests/ingest/integration/ingest/test_relations.py
index 181297fe7e..da4adfc8ca 100644
--- a/nucliadb/tests/ingest/integration/ingest/test_relations.py
+++ b/nucliadb/tests/ingest/integration/ingest/test_relations.py
@@ -25,6 +25,7 @@
 from nucliadb_protos.resources_pb2 import (
     Classification,
     FieldComputedMetadataWrapper,
+    FieldEntity,
     FieldID,
     FieldText,
     FieldType,
@@ -140,6 +141,15 @@ async def test_ingest_field_metadata_relation_extraction(
             field="title",
         )
     )
+    # Data Augmentation + Processor entities
+    fcmw.metadata.metadata.entities["my-task-id"].entities.extend(
+        [
+            FieldEntity(text="value-3", label="subtype-3"),
+            FieldEntity(text="value-4", label="subtype-4"),
+        ]
+    )
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
     fcmw.metadata.metadata.positions["subtype-1/value-1"].entity = "value-1"
     fcmw.metadata.metadata.positions["subtype-1/value-2"].entity = "value-2"
 
@@ -159,7 +169,18 @@ async def test_ingest_field_metadata_relation_extraction(
     pb = await storage.get_indexing(index._calls[0][1])
 
     generated_relations = [
-        # From ner metadata
+        # From data augmentation + processor metadata
+        Relation(
+            relation=Relation.RelationType.ENTITY,
+            source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
+            to=RelationNode(value="value-3", ntype=RelationNode.NodeType.ENTITY, subtype="subtype-3"),
+        ),
+        Relation(
+            relation=Relation.RelationType.ENTITY,
+            source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
+            to=RelationNode(value="value-4", ntype=RelationNode.NodeType.ENTITY, subtype="subtype-4"),
+        ),
+        # From legacy ner metadata
         Relation(
             relation=Relation.RelationType.ENTITY,
             source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
diff --git a/nucliadb/tests/ingest/integration/orm/test_orm_metadata.py b/nucliadb/tests/ingest/integration/orm/test_orm_metadata.py
index fc8bec6170..ee542d6a8c 100644
--- a/nucliadb/tests/ingest/integration/orm/test_orm_metadata.py
+++ b/nucliadb/tests/ingest/integration/orm/test_orm_metadata.py
@@ -29,6 +29,7 @@
     Classification,
     FieldComputedMetadata,
     FieldComputedMetadataWrapper,
+    FieldEntity,
     FieldID,
     FieldType,
     Paragraph,
@@ -58,13 +59,18 @@ async def test_create_resource_orm_metadata(
     p1.classifications.append(cl1)
     ex1.metadata.metadata.paragraphs.append(p1)
     ex1.metadata.metadata.classifications.append(cl1)
-    ex1.metadata.metadata.ner["Ramon"] = "PEOPLE"
     ex1.metadata.metadata.last_index.FromDatetime(datetime.now())
     ex1.metadata.metadata.last_understanding.FromDatetime(datetime.now())
     ex1.metadata.metadata.last_extract.FromDatetime(datetime.now())
-    ex1.metadata.metadata.positions["document"].entity = "Ramon"
-    ex1.metadata.metadata.positions["document"].position.extend(
-        [Position(start=0, end=5), Position(start=23, end=28)]
+    # Data Augmentation + Processor entities
+    ex1.metadata.metadata.entities["my-task-id"].entities.extend(
+        [
+            FieldEntity(
+                text="Ramon",
+                label="PEOPLE",
+                positions=[Position(start=0, end=5), Position(start=23, end=28)],
+            )
+        ]
     )
 
     field_obj: Text = await r.get_field(ex1.field.field, ex1.field.field_type, load=False)
@@ -97,7 +103,10 @@ async def test_create_resource_orm_metadata_split(
     p1.classifications.append(cl1)
     ex1.metadata.split_metadata["ff1"].paragraphs.append(p1)
     ex1.metadata.split_metadata["ff1"].classifications.append(cl1)
-    ex1.metadata.split_metadata["ff1"].ner["Ramon"] = "PEOPLE"
+    ex1.metadata.split_metadata["ff1"].entities["processor"].entities.extend(
+        [FieldEntity(text="Ramon", label="PERSON")]
+    )
+
     ex1.metadata.split_metadata["ff1"].last_index.FromDatetime(datetime.now())
     ex1.metadata.split_metadata["ff1"].last_understanding.FromDatetime(datetime.now())
     ex1.metadata.split_metadata["ff1"].last_extract.FromDatetime(datetime.now())
@@ -117,7 +126,9 @@ async def test_create_resource_orm_metadata_split(
     p1.classifications.append(cl1)
     ex2.metadata.split_metadata["ff2"].paragraphs.append(p1)
     ex2.metadata.split_metadata["ff2"].classifications.append(cl1)
-    ex2.metadata.split_metadata["ff2"].ner["Ramon"] = "PEOPLE"
+    ex2.metadata.split_metadata["ff2"].entities["processor"].entities.extend(
+        [FieldEntity(text="Ramon", label="PEOPLE")]
+    )
     ex2.metadata.split_metadata["ff2"].last_index.FromDatetime(datetime.now())
     ex2.metadata.split_metadata["ff2"].last_understanding.FromDatetime(datetime.now())
     ex2.metadata.split_metadata["ff2"].last_extract.FromDatetime(datetime.now())
diff --git a/nucliadb/tests/ingest/integration/orm/test_orm_resource.py b/nucliadb/tests/ingest/integration/orm/test_orm_resource.py
index efba2d9d33..decb643fff 100644
--- a/nucliadb/tests/ingest/integration/orm/test_orm_resource.py
+++ b/nucliadb/tests/ingest/integration/orm/test_orm_resource.py
@@ -369,7 +369,9 @@ async def test_generate_broker_message(
     lfcm = [fcm for fcm in bm.field_metadata if fcm.field.field == "link1"][0]
     assert lfcm.metadata.metadata.links[0] == "https://nuclia.com"
     assert len(lfcm.metadata.metadata.paragraphs) == 1
-    assert len(lfcm.metadata.metadata.positions) == 1
+    assert len(lfcm.metadata.metadata.entities["processor"].entities) == 1
+    assert len(lfcm.metadata.metadata.entities["my-task-id"].entities) == 1
+
     assert lfcm.metadata.metadata.HasField("last_index")
     assert lfcm.metadata.metadata.HasField("last_understanding")
     assert lfcm.metadata.metadata.HasField("last_extract")
@@ -445,7 +447,9 @@ async def test_generate_index_message_contains_all_metadata(
             assert field in fields_to_be_found
             fields_to_be_found.remove(field)
             assert text.text == "MyText"
-            assert {"/l/labelset1/label1", "/e/ENTITY/document"}.issubset(set(text.labels))
+            assert {"/l/labelset1/label1", "/e/ENTITY/document", "/e/NOUN/document"}.issubset(
+                set(text.labels)
+            )
 
         if field in ("u/link", "t/text1"):
             assert "/e/Location/My home" in text.labels
diff --git a/nucliadb/tests/nucliadb/integration/search/test_filters.py b/nucliadb/tests/nucliadb/integration/search/test_filters.py
index b725944253..01ef6b5e86 100644
--- a/nucliadb/tests/nucliadb/integration/search/test_filters.py
+++ b/nucliadb/tests/nucliadb/integration/search/test_filters.py
@@ -30,6 +30,7 @@
     ExtractedTextWrapper,
     ExtractedVectorsWrapper,
     FieldComputedMetadataWrapper,
+    FieldEntity,
     FieldID,
     FieldType,
     Paragraph,
@@ -115,9 +116,13 @@ def broker_message_with_entities(kbid):
     fmw = FieldComputedMetadataWrapper()
     fmw.field.CopyFrom(field)
     family, entity = EntityLabels.DETECTED.split("/")
-    fmw.metadata.metadata.ner[entity] = family
     pos = Position(start=60, end=64)
-    fmw.metadata.metadata.positions[EntityLabels.DETECTED].position.append(pos)
+    # Data Augmentation + Processor entities
+    fmw.metadata.metadata.entities["my-task-id"].entities.extend(
+        [
+            FieldEntity(text=entity, label=family, positions=[pos]),
+        ]
+    )
 
     par1 = Paragraph()
     par1.start = 0
diff --git a/nucliadb/tests/nucliadb/integration/test_api.py b/nucliadb/tests/nucliadb/integration/test_api.py
index 791c388c16..53c8652af1 100644
--- a/nucliadb/tests/nucliadb/integration/test_api.py
+++ b/nucliadb/tests/nucliadb/integration/test_api.py
@@ -343,15 +343,17 @@ async def test_extracted_shortened_metadata(
     fcmw.metadata.metadata.relations.append(relations)
     fcmw.metadata.split_metadata["split"].relations.append(relations)
 
-    # Add some ners
-    ner = {"Barcelona": "CITY/Barcelona"}
-    fcmw.metadata.metadata.ner.update(ner)
-    fcmw.metadata.split_metadata["split"].ner.update(ner)
-
-    # Add some positions
-    position = rpb.Position(start=1, end=2)
-    fcmw.metadata.metadata.positions["foo"].position.append(position)
-    fcmw.metadata.split_metadata["split"].positions["foo"].position.append(position)
+    # Add some ners with positions
+    fcmw.metadata.metadata.entities["processor"].entities.extend(
+        [
+            rpb.FieldEntity(text="Barcelona", label="CITY", positions=[rpb.Position(start=1, end=2)]),
+        ]
+    )
+    fcmw.metadata.split_metadata["split"].entities["processor"].entities.extend(
+        [
+            rpb.FieldEntity(text="Barcelona", label="CITY", positions=[rpb.Position(start=1, end=2)]),
+        ]
+    )
 
     # Add some classification
     classification = rpb.Classification(label="foo", labelset="bar")
@@ -362,6 +364,7 @@ async def test_extracted_shortened_metadata(
 
     await inject_message(nucliadb_grpc, br)
 
+    # TODO: Remove ner and positions once fields are removed
     cropped_fields = ["ner", "positions", "relations", "classifications"]
 
     # Check that when 'shortened_metadata' in extracted param fields are cropped
diff --git a/nucliadb/tests/nucliadb/integration/test_labels.py b/nucliadb/tests/nucliadb/integration/test_labels.py
index 723a075f5f..3c2ab164df 100644
--- a/nucliadb/tests/nucliadb/integration/test_labels.py
+++ b/nucliadb/tests/nucliadb/integration/test_labels.py
@@ -117,7 +117,10 @@ def broker_resource(knowledgebox: str) -> BrokerMessage:
     fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
-    fcm.metadata.metadata.ner["Ramon"] = "PERSON"
+    fcm.metadata.metadata.entities["processor"].entities.extend(
+        [rpb.FieldEntity(text="Ramon", label="PERSON")]
+    )
+
     fcm.metadata.metadata.classifications.append(c1)
     bm.field_metadata.append(fcm)
diff --git a/nucliadb/tests/nucliadb/integration/test_pinecone_kb.py b/nucliadb/tests/nucliadb/integration/test_pinecone_kb.py
index d4876552d8..5332bd29c4 100644
--- a/nucliadb/tests/nucliadb/integration/test_pinecone_kb.py
+++ b/nucliadb/tests/nucliadb/integration/test_pinecone_kb.py
@@ -318,7 +318,9 @@ async def _inject_broker_message(nucliadb_grpc: WriterStub, kbid: str, rid: str,
     fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
     fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
-    fcm.metadata.metadata.ner["Ramon"] = "PERSON"
+    fcm.metadata.metadata.entities["processor"].entities.extend(
+        [rpb.FieldEntity(text="Ramon", label="PERSON")]
+    )
 
     c1 = rpb.Classification()
     c1.label = "label1"
diff --git a/nucliadb/tests/reader/integration/api/v1/test_reader_resource.py b/nucliadb/tests/reader/integration/api/v1/test_reader_resource.py
index ccfa764c29..6191cf16de 100644
--- a/nucliadb/tests/reader/integration/api/v1/test_reader_resource.py
+++ b/nucliadb/tests/reader/integration/api/v1/test_reader_resource.py
@@ -314,4 +314,15 @@ async def test_get_resource_extracted_metadata(nucliadb_reader: AsyncClient, tes
     resource = resp.json()
 
     metadata = resource["data"]["texts"]["text1"]["extracted"]["metadata"]["metadata"]
+
+    # Check that the processor entity is in the legacy metadata
+    # TODO: Remove once deprecated fields are removed
     assert metadata["positions"]["ENTITY/document"]["entity"] == "document"
+    # Check that we received entities in the new fields
+    assert metadata["entities"]["processor"]["entities"][0]["text"] == "document"
+    assert metadata["entities"]["processor"]["entities"][0]["label"] == "ENTITY"
+    assert len(metadata["entities"]["processor"]["entities"][0]["positions"]) == 2
+
+    assert metadata["entities"]["my-task-id"]["entities"][0]["text"] == "document"
+    assert metadata["entities"]["my-task-id"]["entities"][0]["label"] == "NOUN"
+    assert len(metadata["entities"]["my-task-id"]["entities"][0]["positions"]) == 2
diff --git a/nucliadb/tests/search/unit/search/test_chat_prompt.py b/nucliadb/tests/search/unit/search/test_chat_prompt.py
index a8d4495c95..e17f00771c 100644
--- a/nucliadb/tests/search/unit/search/test_chat_prompt.py
+++ b/nucliadb/tests/search/unit/search/test_chat_prompt.py
@@ -419,7 +419,10 @@ async def test_extend_prompt_context_with_metadata():
     resource.get_basic = AsyncMock(return_value=basic)
     field = mock.Mock()
     fcm = rpb2.FieldComputedMetadata()
-    fcm.metadata.ner.update({"Barcelona": "LOCATION"})
+    fcm.metadata.entities["processor"].entities.extend(
+        [rpb2.FieldEntity(text="Barcelona", label="LOCATION")]
+    )
+
     field.get_field_metadata = AsyncMock(return_value=fcm)
     resource.get_field = AsyncMock(return_value=field)
     resource.get_extra = AsyncMock(return_value=extra)
diff --git a/nucliadb/tests/train/fixtures.py b/nucliadb/tests/train/fixtures.py
index 950c5768b1..dc009a5340 100644
--- a/nucliadb/tests/train/fixtures.py
+++ b/nucliadb/tests/train/fixtures.py
@@ -37,6 +37,7 @@
 from nucliadb_protos.resources_pb2 import (
     ExtractedTextWrapper,
     FieldComputedMetadataWrapper,
+    FieldEntity,
     FieldID,
     FieldType,
     Paragraph,
@@ -189,6 +190,15 @@ def broker_processed_resource(knowledgebox, number, rid) -> BrokerMessage:
     fcmw.metadata.metadata.paragraphs.append(p1)
     fcmw.metadata.metadata.paragraphs.append(p2)
 
+    # Data Augmentation + Processor entities
+    fcmw.metadata.metadata.entities["my-task-id"].entities.extend(
+        [
+            FieldEntity(text="Barcelona", label="CITY", positions=[Position(start=43, end=52)]),
+            FieldEntity(text="Manresa", label="CITY"),
+        ]
+    )
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
     # Add a ner with positions
     fcmw.metadata.metadata.ner.update(
         {
@@ -198,6 +208,7 @@ def broker_processed_resource(knowledgebox, number, rid) -> BrokerMessage:
     )
     fcmw.metadata.metadata.positions["CITY/Barcelona"].entity = "Barcelona"
     fcmw.metadata.metadata.positions["CITY/Barcelona"].position.append(Position(start=43, end=52))
+
     message2.field_metadata.append(fcmw)
 
     etw = ExtractedTextWrapper()
diff --git a/nucliadb/tests/utils/broker_messages/fields.py b/nucliadb/tests/utils/broker_messages/fields.py
index 5b7549103a..7336c59d65 100644
--- a/nucliadb/tests/utils/broker_messages/fields.py
+++ b/nucliadb/tests/utils/broker_messages/fields.py
@@ -146,10 +146,17 @@ def with_user_entity(self, klass: str, name: str, *, start: int, end: int):
         )
         self._user_metadata.token.append(entity)
 
-    def with_extracted_entity(self, klass: str, name: str, *, positions: list[rpb.Position]):
-        entity = self._extracted_metadata.metadata.metadata.positions[f"{klass}/{name}"]
-        entity.entity = name
-        entity.position.extend(positions)
+    def with_extracted_entity(
+        self,
+        klass: str,
+        name: str,
+        *,
+        positions: list[rpb.Position],
+        data_augmentation_task_id: str = "processor",
+    ):
+        self._extracted_metadata.metadata.metadata.entities[data_augmentation_task_id].entities.append(
+            rpb.FieldEntity(text=name, label=klass, positions=positions)
+        )
 
     def with_user_paragraph_labels(self, key: str, labelset: str, labels: list[str]):
         classifications = labels_to_classifications(labelset, labels)
diff --git a/nucliadb_models/src/nucliadb_models/extracted.py b/nucliadb_models/src/nucliadb_models/extracted.py
index c315dfb4c1..cfe0961526 100644
--- a/nucliadb_models/src/nucliadb_models/extracted.py
+++ b/nucliadb_models/src/nucliadb_models/extracted.py
@@ -104,10 +104,25 @@ class Positions(BaseModel):
     entity: str
 
 
+class FieldEntity(BaseModel):
+    text: str
+    label: str
+    positions: List[Position]
+
+
+class FieldEntities(BaseModel):
+    """
+    Wrapper for the entities extracted from a field (required because protobuf doesn't support lists of lists)
+    """
+
+    entities: List[FieldEntity]
+
+
 class FieldMetadata(BaseModel):
     links: List[str]
     paragraphs: List[Paragraph]
-    ner: Dict[str, str]
+    ner: Dict[str, str]  # TODO: Remove once processor doesn't use this anymore
+    entities: Dict[str, FieldEntities]
     classifications: List[Classification]
     last_index: Optional[datetime] = None
     last_understanding: Optional[datetime] = None
@@ -116,7 +131,7 @@ class FieldMetadata(BaseModel):
     thumbnail: Optional[CloudLink] = None
     language: Optional[str] = None
     summary: Optional[str] = None
-    positions: Dict[str, Positions]
+    positions: Dict[str, Positions]  # TODO: Remove once processor doesn't use this anymore
     relations: Optional[List[Relation]] = None
 
 
@@ -152,7 +167,7 @@ def shorten_fieldmetadata(
         cls,
         message: resources_pb2.FieldComputedMetadata,
     ) -> None:
-        large_fields = ["ner", "relations", "positions", "classifications"]
+        large_fields = ["ner", "relations", "positions", "classifications", "entities"]
         for field in large_fields:
             message.metadata.ClearField(field)  # type: ignore
         for metadata in message.split_metadata.values():
@@ -322,6 +337,25 @@ def from_message(
 def convert_fieldmetadata_pb_to_dict(
     message: resources_pb2.FieldMetadata,
 ) -> Dict[str, Any]:
+    # Backwards compatibility with the old entities format
+    # TODO: Remove once deprecated fields are removed
+    # If we received processor entities in the new field and the old fields are empty, copy them to the old fields
+    if "processor" in message.entities and len(message.positions) == 0 and len(message.ner) == 0:
+        message.ner.update(
+            {entity.text: entity.label for entity in message.entities["processor"].entities}
+        )
+        for entity in message.entities["processor"].entities:
+            message.positions[entity.label + "/" + entity.text].entity = entity.text
+            message.positions[entity.label + "/" + entity.text].position.extend(
+                [
+                    resources_pb2.Position(
+                        start=position.start,
+                        end=position.end,
+                    )
+                    for position in entity.positions
+                ]
+            )
+
     value = MessageToDict(
         message,
         preserving_proto_field_name=True,