From 8dcafc5ce7eceb0590ccea9b4d301d3ba13cb816 Mon Sep 17 00:00:00 2001
From: Kenneth Enevoldsen
Date: Tue, 19 Sep 2023 14:56:39 +0200
Subject: [PATCH] feat: keep span annotations for ents

---
 src/augmenty/span/entities.py | 80 ++++++++++++++++++++++++++++-----
 src/augmenty/span/utils.py    | 39 ++++++++++++++++
 tests/test_spans.py           | 83 +++++++++++++++++++++++++++++++++++
 3 files changed, 192 insertions(+), 10 deletions(-)
 create mode 100644 src/augmenty/span/utils.py

diff --git a/src/augmenty/span/entities.py b/src/augmenty/span/entities.py
index 9a12b788..7120549f 100644
--- a/src/augmenty/span/entities.py
+++ b/src/augmenty/span/entities.py
@@ -18,13 +18,26 @@
 from spacy.training import Example
 from spacy.util import registry
 
+from augmenty import span
+
 from ..augment_utilities import make_text_from_orth
+from .utils import offset_range
 
 # create entity type
 ENTITY = Union[str, List[str], Span, Doc]
 
 
-def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
+def _spacing_to_str(spacing: Union[List[str], List[bool]]) -> List[str]:
+    def to_string(x: Union[str, bool]) -> str:
+        if isinstance(x, str):
+            return x
+        else:
+            return " " if x else ""
+
+    return [to_string(x) for x in spacing]
+
+
+def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, Any]:
     spacy = None
     pos = None
     tag = None
@@ -50,7 +63,7 @@ def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
         )
     # if not specified use default values
     if spacy is None:
-        spacy = [True] * len(orth)
+        spacy = [" "] * len(orth)
     if pos is None:
         pos = ["PROPN"] * len(orth)
     if tag is None:
@@ -60,6 +73,12 @@
     if lemma is None:
         lemma = orth
 
+    _spacy = _spacing_to_str(spacy)
+    str_repr = ""
+    for e, s in zip(orth[:-1], _spacy[:-1]):
+        str_repr += e + s
+    str_repr += orth[-1]
+
     return {
         "ORTH": orth,
         "SPACY": spacy,
@@ -67,9 +86,34 @@
         "TAG": tag,
         "MORPH": morph,
         "LEMMA": lemma,
+        "STR": str_repr,
     }
 
 
+def _update_span_annotations(
+    span_anno: Dict[str, list],
+    ent: Span,
+    offset: int,
+    entity_offset: int,
+) -> Dict[str, list]:
+    """Update the span annotations to be in line with the new doc."""
+    ent_range = (ent.start + offset, ent.end + offset)
+
+    for anno_key, spans in span_anno.items():
+        new_spans = []
+        for span_start, span_end, _, __ in spans:
+            span_start, span_end = offset_range(
+                current_range=(span_start, span_end),
+                inserted_range=ent_range,
+                offset=entity_offset,
+            )
+            new_spans.append((span_start, span_end, _, __))
+
+        span_anno[anno_key] = new_spans
+
+    return span_anno
+
+
 def ent_augmenter_v1(
     nlp: Language,
     example: Example,
@@ -82,10 +126,14 @@
     example_dict = example.to_dict()
 
     offset = 0
+    str_offset = 0
+    spans_anno = example_dict["doc_annotation"]["spans"]
     tok_anno = example_dict["token_annotation"]
     ents = example_dict["doc_annotation"]["entities"]
-    if example.y.has_annotation("HEAD") and resolve_dependencies:
+
+    should_update_heads = example.y.has_annotation("HEAD") and resolve_dependencies
+    if should_update_heads:
         head = np.array(tok_anno["HEAD"])
 
     for ent in example.y.ents:
@@ -105,10 +153,13 @@
             normalized_ent = __normalize_entity(new_ent, nlp)
             new_ent = normalized_ent["ORTH"]
             spacing = normalized_ent["SPACY"]
+            str_ent = normalized_ent["STR"]
 
             # Handle token annotations
             len_ent = len(new_ent)
-            i = slice(ent.start + offset, ent.end + offset)
+            str_len_ent = len(str_ent)
+            ent_range = (ent.start + offset, ent.end + offset)
+            i = slice(*ent_range)
 
             tok_anno["ORTH"][i] = new_ent
             tok_anno["LEMMA"][i] = normalized_ent["LEMMA"]
@@ -125,11 +176,12 @@
                 spacing[-1:] = [ent[-1].whitespace_]
             tok_anno["SPACY"][i] = spacing
 
-            offset_ = len_ent - (ent.end - ent.start)
-            if example.y.has_annotation("HEAD") and resolve_dependencies:
+            entity_offset = len_ent - (ent.end - ent.start)
+            entity_str_offset = str_len_ent - len(ent.text)
+            if should_update_heads:
                 # Handle HEAD
-                head[head > ent.start + offset] += offset_
+                head[head > ent.start + offset] += entity_offset
                 # keep first head correcting for changing entity size, set rest to
                 # refer to index of first name
                 head = np.concatenate(
@@ -142,7 +194,15 @@
                     np.array(head[ent.end + offset :]),  # after
                 ],
             )
-            offset += offset_
+
+            spans_anno = _update_span_annotations(
+                spans_anno,
+                ent,
+                str_offset,
+                entity_str_offset,
+            )
+            offset += entity_offset
+            str_offset += entity_str_offset
 
             # Handle entities IOB tags
             if len_ent == 1:
@@ -154,8 +214,8 @@
                     + ["L-" + ent.label_]
                 )
 
-    if example.y.has_annotation("HEAD") and resolve_dependencies:
-        tok_anno["HEAD"] = head.tolist()
+    if should_update_heads:
+        tok_anno["HEAD"] = head.tolist()  # type: ignore
     else:
         tok_anno["HEAD"] = list(range(len(tok_anno["ORTH"])))
diff --git a/src/augmenty/span/utils.py b/src/augmenty/span/utils.py
new file mode 100644
index 00000000..761ccc64
--- /dev/null
+++ b/src/augmenty/span/utils.py
@@ -0,0 +1,39 @@
+from typing import Tuple
+
+
+def offset_range(
+    current_range: Tuple[int, int],
+    inserted_range: Tuple[int, int],
+    offset: int,
+) -> Tuple[int, int]:
+    """Update current_range to account for a replacement at inserted_range.
+
+    Args:
+        current_range: The range whose indices should be updated.
+        inserted_range: The range of the newly inserted span.
+        offset: The length offset to apply to the current range.
+ """ + + start, end = current_range + + if offset == 0: + return current_range + + is_within_range = ( + inserted_range[0] <= start <= inserted_range[1] + or inserted_range[0] <= end <= inserted_range[1] + ) + if is_within_range: + return start, end + offset + + is_before_range = start < inserted_range[0] + if is_before_range: + return start, end + + is_after_range = end > inserted_range[1] + if is_after_range: + return start + offset, end + offset + + raise ValueError( + f"Current range {current_range} is not within inserted range {inserted_range}", + ) diff --git a/tests/test_spans.py b/tests/test_spans.py index 38e5394e..184a70c4 100644 --- a/tests/test_spans.py +++ b/tests/test_spans.py @@ -1,3 +1,7 @@ +from typing import Callable + +import pytest +from spacy.language import Language from spacy.tokens import Doc import augmenty @@ -5,6 +9,85 @@ from .fixtures import nlp_en, nlp_en_md # noqa +@pytest.fixture +def doc(nlp_en: Language) -> Doc: # noqa + doc = Doc( + nlp_en.vocab, + words=[ + "Augmenty", + "is", + "a", + "wonderful", + "tool", + "for", + "augmentation", + ".", + ], + spaces=[True] * 6 + [False] * 2, + ents=["B-ORG"] + ["O"] * 7, + ) + return doc + + +@pytest.fixture +def ent_augmenter(): + ent_augmenter = augmenty.load( + "ents_replace_v1", # type: ignore + level=1.00, + ent_dict={"ORG": [["SpaCy"]]}, + ) + return ent_augmenter + + +@pytest.mark.parametrize( + "nlp", + [ + pytest.lazy_fixture("nlp_en"), + pytest.lazy_fixture("nlp_en_md"), + ], +) +def test_ent_replace_with_span_annotations( + doc: Doc, + ent_augmenter: Callable, + nlp: Language, +): + # add span annotations + positive_noun_chunks = [doc[3:5]] + is_augmenty = [doc[0:1]] + doc.spans["positive_noun_chunks"] = positive_noun_chunks + doc.spans["is_augmenty"] = is_augmenty + + docs = list(augmenty.docs([doc], augmenter=ent_augmenter, nlp=nlp)) + + # Check spans + doc_pos_noun_chunks = docs[0].spans["positive_noun_chunks"] + assert doc_pos_noun_chunks[0].text == "wonderful tool", "the span is not maintained" + + doc_is_augmenty = docs[0].spans["is_augmenty"] + assert doc_is_augmenty[0].text == "SpaCy", "the span is not maintained" + + +@pytest.mark.parametrize( + "nlp", + [ + pytest.lazy_fixture("nlp_en"), + pytest.lazy_fixture("nlp_en_md"), + ], +) +def test_ent_replace_with_cats_annotations( + doc: Doc, + ent_augmenter: Callable, + nlp: Language, +): + # add doc annotations + doc.cats["is_positive"] = 1 + + # augment + docs = list(augmenty.docs([doc], augmenter=ent_augmenter, nlp=nlp)) + + assert docs[0].cats["is_positive"] == 1.0, "the document category is not maintained" + + def test_create_ent_replace(nlp_en_md, nlp_en): # noqa F811 doc = Doc( nlp_en.vocab,
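
A quick sketch of the semantics of the new offset_range helper, for reviewers. It assumes the patch is applied so the helper is importable from augmenty.span.utils; the concrete ranges and the offset of 5 are made-up illustration values:

from augmenty.span.utils import offset_range

# A replacement at positions (0, 8) grew the text by 5 characters.
inserted = (0, 8)

# A range that starts after the replaced span is shifted right by the offset.
assert offset_range(current_range=(12, 21), inserted_range=inserted, offset=5) == (17, 26)

# A range that overlaps the replaced span keeps its start; only its end grows.
assert offset_range(current_range=(0, 8), inserted_range=inserted, offset=5) == (0, 13)

# A range that ends before the replaced span is left untouched.
assert offset_range(current_range=(0, 3), inserted_range=(5, 10), offset=5) == (0, 3)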
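
And an end-to-end sketch of the behaviour the new tests lock in: an entity is replaced while doc.spans stays aligned with the new text. It mirrors test_ent_replace_with_span_annotations above; using spacy.blank("en") in place of the nlp_en fixture is an assumption:

import spacy
from spacy.tokens import Doc

import augmenty

nlp = spacy.blank("en")  # assumption: a blank pipeline suffices, as in the fixtures

doc = Doc(
    nlp.vocab,
    words=["Augmenty", "is", "a", "wonderful", "tool", "for", "augmentation", "."],
    spaces=[True] * 6 + [False] * 2,
    ents=["B-ORG"] + ["O"] * 7,
)
# span groups that should survive the entity replacement
doc.spans["is_augmenty"] = [doc[0:1]]           # "Augmenty"
doc.spans["positive_noun_chunks"] = [doc[3:5]]  # "wonderful tool"

ent_augmenter = augmenty.load(
    "ents_replace_v1",
    level=1.00,
    ent_dict={"ORG": [["SpaCy"]]},
)
augmented = list(augmenty.docs([doc], augmenter=ent_augmenter, nlp=nlp))

# the replaced entity is reflected in the span group, the rest are unchanged
assert augmented[0].spans["is_augmenty"][0].text == "SpaCy"
assert augmented[0].spans["positive_noun_chunks"][0].text == "wonderful tool"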