Skip to content

Commit

Permalink
Use new entities field in FieldMetadata (#2588)
Browse files Browse the repository at this point in the history
* Use new entities field in FieldMetadata

* lint + format

* move to new field almost everywhere

* remove pdb

* Fix test

* Fix test

* Fixed comments
  • Loading branch information
carlesonielfa authored Oct 31, 2024
1 parent 302bacc commit d9f7e64
Show file tree
Hide file tree
Showing 18 changed files with 286 additions and 57 deletions.
66 changes: 48 additions & 18 deletions nucliadb/src/nucliadb/ingest/orm/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,24 +498,54 @@ def process_field_metadata(
to=relation_node_label,
)
)

for klass_entity, _ in metadata.positions.items():
labels["e"].add(klass_entity)
entity_array = klass_entity.split("/")
if len(entity_array) == 1:
raise AttributeError(f"Entity should be with type {klass_entity}")
elif len(entity_array) > 1:
klass = entity_array[0]
entity = "/".join(entity_array[1:])
relation_node_entity = RelationNode(
value=entity, ntype=RelationNode.NodeType.ENTITY, subtype=klass
)
rel = Relation(
relation=Relation.ENTITY,
source=relation_node_document,
to=relation_node_entity,
)
self.brain.relations.append(rel)
# Data Augmentation + Processor entities
use_legacy_entities = True
for data_augmentation_task_id, entities in metadata.entities.items():
            # If we received the entities from the processor here, we don't want to use the legacy entities
# TODO: Remove this when processor doesn't use this anymore
if data_augmentation_task_id == "processor":
use_legacy_entities = False

for ent in entities.entities:
entity_text = ent.text
entity_label = ent.label
# Seems like we don't care about where the entity is in the text
# entity_positions = entity.positions
labels["e"].add(
f"{entity_label}/{entity_text}"
) # Add data_augmentation_task_id as a prefix?
relation_node_entity = RelationNode(
value=entity_text,
ntype=RelationNode.NodeType.ENTITY,
subtype=entity_label,
)
rel = Relation(
relation=Relation.ENTITY,
source=relation_node_document,
to=relation_node_entity,
)
self.brain.relations.append(rel)

# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
if use_legacy_entities:
for klass_entity, _ in metadata.positions.items():
labels["e"].add(klass_entity)
entity_array = klass_entity.split("/")
if len(entity_array) == 1:
raise AttributeError(f"Entity should be with type {klass_entity}")
elif len(entity_array) > 1:
klass = entity_array[0]
entity = "/".join(entity_array[1:])
relation_node_entity = RelationNode(
value=entity, ntype=RelationNode.NodeType.ENTITY, subtype=klass
)
rel = Relation(
relation=Relation.ENTITY,
source=relation_node_document,
to=relation_node_entity,
)
self.brain.relations.append(rel)

def apply_field_labels(
self,
Expand Down
30 changes: 25 additions & 5 deletions nucliadb/src/nucliadb/ingest/orm/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncIterator, Optional, Type
from typing import TYPE_CHECKING, Any, AsyncIterator, MutableMapping, Optional, Type

from nucliadb.common import datamanagers
from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG
Expand Down Expand Up @@ -890,7 +890,7 @@ async def iterate_sentences(

entities: dict[str, str] = {}
if enabled_metadata.entities:
entities.update(field_metadata.ner)
_update_entities_dict(entities, field_metadata)

precomputed_vectors = {}
if vo is not None:
Expand Down Expand Up @@ -996,7 +996,7 @@ async def iterate_paragraphs(

entities: dict[str, str] = {}
if enabled_metadata.entities:
entities.update(field_metadata.ner)
_update_entities_dict(entities, field_metadata)

if extracted_text is not None:
if subfield is not None:
Expand Down Expand Up @@ -1075,7 +1075,7 @@ async def iterate_fields(self, enabled_metadata: EnabledMetadata) -> AsyncIterat

if enabled_metadata.entities:
metadata.ClearField("entities")
metadata.entities.update(splitted_metadata.ner)
_update_entities_dict(metadata.entities, splitted_metadata)

pb_field = TrainField()
pb_field.uuid = self.uuid
Expand Down Expand Up @@ -1119,7 +1119,7 @@ async def generate_train_resource(self, enabled_metadata: EnabledMetadata) -> Tr
metadata.labels.field.extend(splitted_metadata.classifications)

if enabled_metadata.entities:
metadata.entities.update(splitted_metadata.ner)
_update_entities_dict(metadata.entities, splitted_metadata)

pb_resource = TrainResource()
pb_resource.uuid = self.uuid
Expand Down Expand Up @@ -1254,3 +1254,23 @@ def extract_field_metadata_languages(
for _, splitted_metadata in field_metadata.metadata.split_metadata.items():
languages.add(splitted_metadata.language)
return list(languages)


def _update_entities_dict(target_entites_dict: MutableMapping[str, str], field_metadata: FieldMetadata):
"""
Update the entities dict with the entities from the field metadata.
Method created to ease the transition from legacy ner field to new entities field.
"""
# Data Augmentation + Processor entities
# This will overwrite entities detected from more than one data augmentation task
# TODO: Change TrainMetadata proto to accept multiple entities with the same text
entity_map = {
entity.text: entity.label
for data_augmentation_task_id, entities_wrapper in field_metadata.entities.items()
for entity in entities_wrapper.entities
}
target_entites_dict.update(entity_map)

# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
target_entites_dict.update(field_metadata.ner)
6 changes: 6 additions & 0 deletions nucliadb/src/nucliadb/search/search/chat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str,
field = await resource.get_field(fid.key, fid.pb_type, load=False)
fcm = await field.get_field_metadata()
if fcm is not None:
# Data Augmentation + Processor entities
for data_aumgentation_task_id, entities_wrapper in fcm.metadata.entities.items():
for entity in entities_wrapper.entities:
ners.setdefault(entity.label, set()).add(entity.text)
# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
for token, family in fcm.metadata.ner.items():
ners.setdefault(family, set()).add(token)
return _id, ners
Expand Down
30 changes: 30 additions & 0 deletions nucliadb/src/nucliadb/train/generators/token_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,36 @@ async def get_field_text(
field_metadata = await field_obj.get_field_metadata()
# Check computed definition of entities
if field_metadata is not None:
# Data Augmentation + Processor entities
for data_augmentation_task_id, entities in field_metadata.metadata.entities.items():
for entity in entities.entities:
entity_text = entity.text
entity_label = entity.label
entity_positions = entity.positions
if entity_label in valid_entity_groups:
split_ners[MAIN].setdefault(entity_label, {}).setdefault(entity_text, [])
for position in entity_positions:
split_ners[MAIN][entity_label][entity_text].append(
(position.start, position.end)
)

for split, split_metadata in field_metadata.split_metadata.items():
for data_augmentation_task_id, entities in split_metadata.entities.items():
for entity in entities.entities:
entity_text = entity.text
entity_label = entity.label
entity_positions = entity.positions
if entity_label in valid_entity_groups:
split_ners.setdefault(split, {}).setdefault(entity_label, {}).setdefault(
entity_text, []
)
for position in entity_positions:
split_ners[split][entity_label][entity_text].append(
(position.start, position.end)
)

# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
for entity_key, positions in field_metadata.metadata.positions.items():
entities = entity_key.split("/")
entity_group = entities[0]
Expand Down
28 changes: 24 additions & 4 deletions nucliadb/tests/ingest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,24 @@ def make_field_metadata(field_id):
ex1.metadata.metadata.last_extract.FromDatetime(datetime.now())
ex1.metadata.metadata.last_summary.FromDatetime(datetime.now())
ex1.metadata.metadata.thumbnail.CopyFrom(THUMBNAIL)
ex1.metadata.metadata.positions["ENTITY/document"].entity = "document"
ex1.metadata.metadata.positions["ENTITY/document"].position.extend(
[rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)]
# Data Augmentation + Processor entities
ex1.metadata.metadata.entities["processor"].entities.extend(
[
rpb.FieldEntity(
text="document",
label="ENTITY",
positions=[rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)],
),
]
)
ex1.metadata.metadata.entities["my-task-id"].entities.extend(
[
rpb.FieldEntity(
text="document",
label="NOUN",
positions=[rpb.Position(start=0, end=5), rpb.Position(start=13, end=18)],
),
]
)
return ex1

Expand Down Expand Up @@ -511,7 +526,12 @@ def broker_resource(
fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
fcm.metadata.metadata.ner["Ramon"] = "PERSON"
fcm.metadata.metadata.entities["processor"].entities.extend(
[rpb.FieldEntity(text="Ramon", label="PERSON")]
)
fcm.metadata.metadata.entities["my-data-augmentation"].entities.extend(
[rpb.FieldEntity(text="Ramon", label="CTO")]
)

c1 = rpb.Classification()
c1.label = "label1"
Expand Down
8 changes: 8 additions & 0 deletions nucliadb/tests/ingest/integration/ingest/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
ExtractedVectorsWrapper,
FieldComputedMetadata,
FieldComputedMetadataWrapper,
FieldEntity,
FieldID,
FieldMetadata,
FieldQuestionAnswerWrapper,
Expand Down Expand Up @@ -166,6 +167,13 @@ async def test_ingest_messages_autocommit(kbid: str, processor):
fcm.metadata.metadata.last_index.FromDatetime(datetime.now())
fcm.metadata.metadata.last_understanding.FromDatetime(datetime.now())
fcm.metadata.metadata.last_extract.FromDatetime(datetime.now())
# Data Augmentation + Processor entities
fcm.metadata.metadata.entities["processor"].entities.extend(
[FieldEntity(text="Ramon", label="PERSON")]
)

# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
fcm.metadata.metadata.ner["Ramon"] = "PERSON"

c1 = Classification()
Expand Down
23 changes: 22 additions & 1 deletion nucliadb/tests/ingest/integration/ingest/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nucliadb_protos.resources_pb2 import (
Classification,
FieldComputedMetadataWrapper,
FieldEntity,
FieldID,
FieldText,
FieldType,
Expand Down Expand Up @@ -140,6 +141,15 @@ async def test_ingest_field_metadata_relation_extraction(
field="title",
)
)
# Data Augmentation + Processor entities
fcmw.metadata.metadata.entities["my-task-id"].entities.extend(
[
FieldEntity(text="value-3", label="subtype-3"),
FieldEntity(text="value-4", label="subtype-4"),
]
)
# Legacy processor entities
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
fcmw.metadata.metadata.positions["subtype-1/value-1"].entity = "value-1"
fcmw.metadata.metadata.positions["subtype-1/value-2"].entity = "value-2"

Expand All @@ -159,7 +169,18 @@ async def test_ingest_field_metadata_relation_extraction(
pb = await storage.get_indexing(index._calls[0][1])

generated_relations = [
# From ner metadata
# From data augmentation + processor metadata
Relation(
relation=Relation.RelationType.ENTITY,
source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
to=RelationNode(value="value-3", ntype=RelationNode.NodeType.ENTITY, subtype="subtype-3"),
),
Relation(
relation=Relation.RelationType.ENTITY,
source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
to=RelationNode(value="value-4", ntype=RelationNode.NodeType.ENTITY, subtype="subtype-4"),
),
# From legacy ner metadata
Relation(
relation=Relation.RelationType.ENTITY,
source=RelationNode(value=rid, ntype=RelationNode.NodeType.RESOURCE),
Expand Down
23 changes: 17 additions & 6 deletions nucliadb/tests/ingest/integration/orm/test_orm_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Classification,
FieldComputedMetadata,
FieldComputedMetadataWrapper,
FieldEntity,
FieldID,
FieldType,
Paragraph,
Expand Down Expand Up @@ -58,13 +59,18 @@ async def test_create_resource_orm_metadata(
p1.classifications.append(cl1)
ex1.metadata.metadata.paragraphs.append(p1)
ex1.metadata.metadata.classifications.append(cl1)
ex1.metadata.metadata.ner["Ramon"] = "PEOPLE"
ex1.metadata.metadata.last_index.FromDatetime(datetime.now())
ex1.metadata.metadata.last_understanding.FromDatetime(datetime.now())
ex1.metadata.metadata.last_extract.FromDatetime(datetime.now())
ex1.metadata.metadata.positions["document"].entity = "Ramon"
ex1.metadata.metadata.positions["document"].position.extend(
[Position(start=0, end=5), Position(start=23, end=28)]
# Data Augmentation + Processor entities
ex1.metadata.metadata.entities["my-task-id"].entities.extend(
[
FieldEntity(
text="Ramon",
label="PEOPLE",
positions=[Position(start=0, end=5), Position(start=23, end=28)],
)
]
)

field_obj: Text = await r.get_field(ex1.field.field, ex1.field.field_type, load=False)
Expand Down Expand Up @@ -97,7 +103,10 @@ async def test_create_resource_orm_metadata_split(
p1.classifications.append(cl1)
ex1.metadata.split_metadata["ff1"].paragraphs.append(p1)
ex1.metadata.split_metadata["ff1"].classifications.append(cl1)
ex1.metadata.split_metadata["ff1"].ner["Ramon"] = "PEOPLE"
ex1.metadata.split_metadata["ff1"].entities["processor"].entities.extend(
[FieldEntity(text="Ramon", label="PERSON")]
)

ex1.metadata.split_metadata["ff1"].last_index.FromDatetime(datetime.now())
ex1.metadata.split_metadata["ff1"].last_understanding.FromDatetime(datetime.now())
ex1.metadata.split_metadata["ff1"].last_extract.FromDatetime(datetime.now())
Expand All @@ -117,7 +126,9 @@ async def test_create_resource_orm_metadata_split(
p1.classifications.append(cl1)
ex2.metadata.split_metadata["ff2"].paragraphs.append(p1)
ex2.metadata.split_metadata["ff2"].classifications.append(cl1)
ex2.metadata.split_metadata["ff2"].ner["Ramon"] = "PEOPLE"
ex1.metadata.split_metadata["ff1"].entities["processor"].entities.extend(
[FieldEntity(text="Ramon", label="PEOPLE")]
)
ex2.metadata.split_metadata["ff2"].last_index.FromDatetime(datetime.now())
ex2.metadata.split_metadata["ff2"].last_understanding.FromDatetime(datetime.now())
ex2.metadata.split_metadata["ff2"].last_extract.FromDatetime(datetime.now())
Expand Down
8 changes: 6 additions & 2 deletions nucliadb/tests/ingest/integration/orm/test_orm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ async def test_generate_broker_message(
lfcm = [fcm for fcm in bm.field_metadata if fcm.field.field == "link1"][0]
assert lfcm.metadata.metadata.links[0] == "https://nuclia.com"
assert len(lfcm.metadata.metadata.paragraphs) == 1
assert len(lfcm.metadata.metadata.positions) == 1
assert len(lfcm.metadata.metadata.entities["processor"].entities) == 1
assert len(lfcm.metadata.metadata.entities["my-task-id"].entities) == 1

assert lfcm.metadata.metadata.HasField("last_index")
assert lfcm.metadata.metadata.HasField("last_understanding")
assert lfcm.metadata.metadata.HasField("last_extract")
Expand Down Expand Up @@ -445,7 +447,9 @@ async def test_generate_index_message_contains_all_metadata(
assert field in fields_to_be_found
fields_to_be_found.remove(field)
assert text.text == "MyText"
assert {"/l/labelset1/label1", "/e/ENTITY/document"}.issubset(set(text.labels))
assert {"/l/labelset1/label1", "/e/ENTITY/document", "/e/NOUN/document"}.issubset(
set(text.labels)
)
if field in ("u/link", "t/text1"):
assert "/e/Location/My home" in text.labels

Expand Down
9 changes: 7 additions & 2 deletions nucliadb/tests/nucliadb/integration/search/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ExtractedTextWrapper,
ExtractedVectorsWrapper,
FieldComputedMetadataWrapper,
FieldEntity,
FieldID,
FieldType,
Paragraph,
Expand Down Expand Up @@ -115,9 +116,13 @@ def broker_message_with_entities(kbid):
fmw = FieldComputedMetadataWrapper()
fmw.field.CopyFrom(field)
family, entity = EntityLabels.DETECTED.split("/")
fmw.metadata.metadata.ner[entity] = family
pos = Position(start=60, end=64)
fmw.metadata.metadata.positions[EntityLabels.DETECTED].position.append(pos)
# Data Augmentation + Processor entities
fmw.metadata.metadata.entities["my-task-id"].entities.extend(
[
FieldEntity(text=entity, label=family, positions=[pos]),
]
)

par1 = Paragraph()
par1.start = 0
Expand Down
Loading

0 comments on commit d9f7e64

Please sign in to comment.