From 44db08b0266846497ead3037b2ac5b94c073c3d4 Mon Sep 17 00:00:00 2001
From: Mart Ratas
Date: Mon, 7 Oct 2024 10:37:55 +0100
Subject: [PATCH] CU-8695ucw9b deid transformers fix (#490)

* CU-8695ucw9b: Fix older DeID models broken by changes in transformers.
  Since transformers 4.42.0, the tokenizer is expected to have the
  'split_special_tokens' attribute, but the version we've saved does not.
  When such a model is loaded, an exception is raised (which is currently
  caught and logged by medcat).
* CU-8695ucw9b: Add functionality for transformers NER to spectacularly
  fail upon consistent consecutive exceptions. The idea is that if
  something in the underlying models is consistently failing, the
  exception is raised rather than simply logged.
* CU-8695ucw9b: Add tests for exception raising after a pre-defined
  number of failed document processes.
* CU-8695ucw9b: Change the conditions for raising an exception on
  consecutive failure. The exception is now only raised if the
  consecutive failures are identical (or similar), as determined from
  the type and string representation of the exception being raised.
* CU-8695ucw9b: Small additional cleanup on successful TNER processing.
* CU-8695ucw9b: Use a custom exception when failing due to consecutive
  exceptions.
* CU-8695ucw9b: Remove the try-except when processing transformers NER
  to force immediate raising of exceptions.
---
 medcat/ner/transformers_ner.py | 63 ++++++++++++++++++----------------
 1 file changed, 34 insertions(+), 29 deletions(-)

diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py
index 32eb2352..1cc4c5f2 100644
--- a/medcat/ner/transformers_ner.py
+++ b/medcat/ner/transformers_ner.py
@@ -4,7 +4,7 @@
 import datasets
 from spacy.tokens import Doc
 from datetime import datetime
-from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable
+from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
 from spacy.tokens import Span
 import inspect
 from functools import partial
@@ -87,7 +87,13 @@ def create_eval_pipeline(self):
         # NOTE: this will fix the DeID model(s) created before medcat 1.9.3
         # though this fix may very well be unstable
         self.ner_pipe.tokenizer._in_target_context_manager = False
+        # if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
+        #     # NOTE: this will fix the DeID model(s) created with transformers before 4.42
+        #     # and allow them to run with later transformers
+        #     self.ner_pipe.tokenizer.split_special_tokens = False
         self.ner_pipe.device = self.model.device
+        self._consecutive_identical_failures = 0
+        self._last_exception: Optional[Tuple[str, Type[Exception]]] = None
 
     def get_hash(self) -> str:
         """A partial hash trying to catch differences between models.
@@ -390,34 +396,33 @@ def _process(self,
         #all_text_processed = self.tokenizer.encode_eval(all_text)
         # For now we will process the documents one by one, should be improved in the future to use batching
         for doc in docs:
-            try:
-                res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
-                doc.ents = []  # type: ignore
-                for r in res:
-                    inds = []
-                    for ind, word in enumerate(doc):
-                        end_char = word.idx + len(word.text)
-                        if end_char <= r['end'] and end_char > r['start']:
-                            inds.append(ind)
-                        # To not loop through everything
-                        if end_char > r['end']:
-                            break
-                    if inds:
-                        entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
-                        entity._.cui = r['entity_group']
-                        entity._.context_similarity = r['score']
-                        entity._.detected_name = r['word']
-                        entity._.id = len(doc._.ents)
-                        entity._.confidence = r['score']
-
-                        doc._.ents.append(entity)
-                create_main_ann(self.cdb, doc)
-                if self.cdb.config.general['make_pretty_labels'] is not None:
-                    make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
-                if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
-                    map_ents_to_groups(self.cdb, doc)
-            except Exception as e:
-                logger.warning(e, exc_info=True)
+            res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
+            doc.ents = []  # type: ignore
+            for r in res:
+                inds = []
+                for ind, word in enumerate(doc):
+                    end_char = word.idx + len(word.text)
+                    if end_char <= r['end'] and end_char > r['start']:
+                        inds.append(ind)
+                    # To not loop through everything
+                    if end_char > r['end']:
+                        break
+                if inds:
+                    entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
+                    entity._.cui = r['entity_group']
+                    entity._.context_similarity = r['score']
+                    entity._.detected_name = r['word']
+                    entity._.id = len(doc._.ents)
+                    entity._.confidence = r['score']
+
+                    doc._.ents.append(entity)
+            create_main_ann(self.cdb, doc)
+            if self.cdb.config.general['make_pretty_labels'] is not None:
+                make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
+            if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
+                map_ents_to_groups(self.cdb, doc)
+            self._consecutive_identical_failures = 0  # success
+            self._last_exception = None
         yield from docs
 
     # Override
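
Note for reviewers: the hunks above show only the state added in create_eval_pipeline and the success-path reset of _consecutive_identical_failures and _last_exception; the increment-and-raise side described in the commit message lives elsewhere in the PR and is not visible here. Below is a minimal, self-contained sketch of how that bookkeeping can work. The names FailureTracker, ConsecutiveFailureError, record_success, record_failure, and the threshold default are illustrative assumptions, not the actual medcat API.

from typing import Optional, Tuple, Type


class ConsecutiveFailureError(Exception):
    """Hypothetical custom exception raised after repeated identical failures."""


class FailureTracker:
    def __init__(self, max_consecutive: int = 3) -> None:
        self.max_consecutive = max_consecutive
        self._consecutive_identical_failures = 0
        # (str(exception), type(exception)) of the last failure, if any;
        # mirrors the annotation added in the first hunk above
        self._last_exception: Optional[Tuple[str, Type[Exception]]] = None

    def record_success(self) -> None:
        # Mirrors the reset done after each successfully processed document
        self._consecutive_identical_failures = 0
        self._last_exception = None

    def record_failure(self, exc: Exception) -> None:
        # Two failures count as "identical" when both their type and their
        # string representation match, as described in the commit message
        key = (str(exc), type(exc))
        if key == self._last_exception:
            self._consecutive_identical_failures += 1
        else:
            self._consecutive_identical_failures = 1
            self._last_exception = key
        if self._consecutive_identical_failures >= self.max_consecutive:
            raise ConsecutiveFailureError(
                f"{self._consecutive_identical_failures} identical consecutive "
                f"failures: {key[1].__name__}: {key[0]}"
            ) from exc

In this sketch, a caller would invoke record_failure(e) in its except block and record_success() after each clean document; once the same exception (same type and message) repeats max_consecutive times in a row, processing aborts with the custom exception instead of the failure being silently logged.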