CU-8695ucw9b deid transformers fix (#490)
* CU-8695ucw9b: Fix older DeID models due to changes in transformers.

Since transformers 4.42.0, the tokenizer is expected to have a 'split_special_tokens' attribute, but the tokenizer saved with our older models does not. When such a model is loaded, this raises an exception (which is currently caught and logged by medcat). A sketch of this kind of compatibility patch follows this list.

* CU-8695ucw9b: Add functionality for transformers NER to fail outright upon consistent consecutive exceptions.

The idea is that, this way, if something in the underlying models is consistently failing, the exception is raised rather than simply logged.

* CU-8695ucw9b: Add tests for exception raising after a pre-defined number of failed document processes

* CU-8695ucw9b: Change conditions for raising exception on consecutive failure.

The exception is now raised only if the consecutive failures are identical (or similar), which we determine from the type and string representation of the exception being raised; a sketch of this bookkeeping also follows the list.

* CU-8695ucw9b: Small additional cleanup on successful TNER processing

* CU-8695ucw9b: Use custom exception when failing due to consecutive exceptions

* CU-8695ucw9b: Remove try-except when processing transformers NER to force immediate raising of exception
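
A minimal sketch of the kind of compatibility patch described in the first bullet, assuming a Hugging Face Pipeline object whose tokenizer was serialized before transformers 4.42; the helper name patch_legacy_tokenizer is illustrative, and the attribute back-fill mirrors the commented-out fix visible in the diff below.

from transformers import Pipeline

def patch_legacy_tokenizer(ner_pipe: Pipeline) -> None:
    # Tokenizers serialized before transformers 4.42 may lack the
    # 'split_special_tokens' attribute that newer pipeline code expects to read.
    # Back-fill it with a conservative default so the pipeline can be called.
    if not hasattr(ner_pipe.tokenizer, "split_special_tokens"):
        ner_pipe.tokenizer.split_special_tokens = False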
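
And a minimal sketch of the consecutive-failure bookkeeping the later bullets describe, under stated assumptions: the threshold, the class name FailureTracker, and the exception name ConsecutiveFailuresError are illustrative rather than medcat's actual identifiers; only the fields _consecutive_identical_failures and _last_exception mirror the diff below.

from typing import Optional, Tuple, Type

MAX_CONSECUTIVE_FAILURES = 3  # illustrative threshold, not medcat's actual value


class ConsecutiveFailuresError(Exception):
    """Raised when the same exception occurs on several consecutive documents."""


class FailureTracker:
    def __init__(self) -> None:
        self._consecutive_identical_failures = 0
        self._last_exception: Optional[Tuple[str, Type[Exception]]] = None

    def record_success(self) -> None:
        # Any successfully processed document resets the streak.
        self._consecutive_identical_failures = 0
        self._last_exception = None

    def record_failure(self, exc: Exception) -> None:
        # Two failures count as "identical (or similar)" when both their type
        # and their string representation match the previous one.
        key = (str(exc), type(exc))
        if key == self._last_exception:
            self._consecutive_identical_failures += 1
        else:
            self._last_exception = key
            self._consecutive_identical_failures = 1
        if self._consecutive_identical_failures >= MAX_CONSECUTIVE_FAILURES:
            raise ConsecutiveFailuresError(
                f"Identical failure on {self._consecutive_identical_failures} "
                "consecutive documents") from exc

Wrapped around per-document processing, this lets occasional, varied failures be logged as before while a repeating, identical failure eventually surfaces; the final commit in this list goes further and drops the try-except entirely so exceptions raise immediately.
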
mart-r authored Oct 7, 2024
1 parent b433195 commit 44db08b
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions medcat/ner/transformers_ner.py
@@ -4,7 +4,7 @@
import datasets
from spacy.tokens import Doc
from datetime import datetime
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
from spacy.tokens import Span
import inspect
from functools import partial
@@ -87,7 +87,13 @@ def create_eval_pipeline(self):
# NOTE: this will fix the DeID model(s) created before medcat 1.9.3
# though this fix may very well be unstable
self.ner_pipe.tokenizer._in_target_context_manager = False
# if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
# # NOTE: this will fix the DeID model(s) created with transformers before 4.42
# # and allow them to run with later transformers
# self.ner_pipe.tokenizer.split_special_tokens = False
self.ner_pipe.device = self.model.device
self._consecutive_identical_failures = 0
self._last_exception: Optional[Tuple[str, Type[Exception]]] = None

def get_hash(self) -> str:
"""A partial hash trying to catch differences between models.
@@ -390,34 +396,33 @@ def _process(self,
#all_text_processed = self.tokenizer.encode_eval(all_text)
# For now we will process the documents one by one, should be improved in the future to use batching
for doc in docs:
try:
res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
doc.ents = [] # type: ignore
for r in res:
inds = []
for ind, word in enumerate(doc):
end_char = word.idx + len(word.text)
if end_char <= r['end'] and end_char > r['start']:
inds.append(ind)
# To not loop through everything
if end_char > r['end']:
break
if inds:
entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
entity._.cui = r['entity_group']
entity._.context_similarity = r['score']
entity._.detected_name = r['word']
entity._.id = len(doc._.ents)
entity._.confidence = r['score']

doc._.ents.append(entity)
create_main_ann(self.cdb, doc)
if self.cdb.config.general['make_pretty_labels'] is not None:
make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
map_ents_to_groups(self.cdb, doc)
except Exception as e:
logger.warning(e, exc_info=True)
res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
doc.ents = [] # type: ignore
for r in res:
inds = []
for ind, word in enumerate(doc):
end_char = word.idx + len(word.text)
if end_char <= r['end'] and end_char > r['start']:
inds.append(ind)
# To not loop through everything
if end_char > r['end']:
break
if inds:
entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
entity._.cui = r['entity_group']
entity._.context_similarity = r['score']
entity._.detected_name = r['word']
entity._.id = len(doc._.ents)
entity._.confidence = r['score']

doc._.ents.append(entity)
create_main_ann(self.cdb, doc)
if self.cdb.config.general['make_pretty_labels'] is not None:
make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
map_ents_to_groups(self.cdb, doc)
self._consecutive_identical_failures = 0 # success
self._last_exception = None
yield from docs

# Override
