Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

CU-8695ucw9b deid transformers fix #490

Merged
merged 8 commits into from
Oct 7, 2024
35 changes: 34 additions & 1 deletion medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datasets
from spacy.tokens import Doc
from datetime import datetime
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
from spacy.tokens import Span
import inspect
from functools import partial
Expand Down Expand Up @@ -32,6 +32,14 @@
logger = logging.getLogger(__name__)


# generally, the code below catches all exceptions and logs them
# but if we have too many consecutive failures, it might indicate
# an underlying issue that should be reported explicitly so that
# it can be fixed. Otherwise we might end up running incompatible
# models that keep raising exceptions but never explicitly failing
# NOTE: consulted by the _process failure counter and used as the
#       default count in TooManyConsecutiveFailuresException
RAISE_AFTER_CONSECUTIVE_IDENTICAL_FAILURES = 10


class TransformersNER(object):
"""TODO: Add documentation"""

Expand Down Expand Up @@ -87,7 +95,13 @@ def create_eval_pipeline(self):
# NOTE: this will fix the DeID model(s) created before medcat 1.9.3
# though this fix may very well be unstable
self.ner_pipe.tokenizer._in_target_context_manager = False
# if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
# # NOTE: this will fix the DeID model(s) created with transformers before 4.42
# # and allow them to run with later transformers
# self.ner_pipe.tokenizer.split_special_tokens = False
self.ner_pipe.device = self.model.device
self._consecutive_identical_failures = 0
self._last_exception: Optional[Tuple[str, Type[Exception]]] = None

def get_hash(self) -> str:
"""A partial hash trying to catch differences between models.
Expand Down Expand Up @@ -416,8 +430,21 @@ def _process(self,
make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
map_ents_to_groups(self.cdb, doc)
self._consecutive_identical_failures = 0 # success
self._last_exception = None
except Exception as e:
logger.warning(e, exc_info=True)
# NOTE: exceptions are rarely 'equal' so if the message and type
# are the same, we consider them the same
ex_info = (str(e), type(e))
if self._last_exception == ex_info:
self._consecutive_identical_failures += 1
else:
self._consecutive_identical_failures = 1
self._last_exception = ex_info
if self._consecutive_identical_failures >= RAISE_AFTER_CONSECUTIVE_IDENTICAL_FAILURES:
cnt = self._consecutive_identical_failures
raise TooManyConsecutiveFailuresException(cnt) from e
yield from docs

# Override
Expand All @@ -439,6 +466,12 @@ def __call__(self, doc: Doc) -> Doc:
return doc


class TooManyConsecutiveFailuresException(Exception):
    """Raised when the pipeline fails identically too many times in a row.

    The number of consecutive similar failures that triggered the raise
    is embedded in the exception message.
    """

    def __init__(self, cnt: int = RAISE_AFTER_CONSECUTIVE_IDENTICAL_FAILURES) -> None:
        message = "Got too many ({}) consecutive similar exceptions".format(cnt)
        super().__init__(message)


# NOTE: Only needed for datasets backwards compatibility
def func_has_kwarg(func: Callable, keyword: str):
sig = inspect.signature(func)
Expand Down
64 changes: 64 additions & 0 deletions tests/ner/test_transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from spacy.tokens import Doc, Span
from transformers import TrainerCallback
from medcat.ner.transformers_ner import TransformersNER
from medcat.ner.transformers_ner import RAISE_AFTER_CONSECUTIVE_IDENTICAL_FAILURES
from medcat.ner.transformers_ner import TooManyConsecutiveFailuresException
from medcat.config import Config
from medcat.cdb_maker import CDBMaker

Expand Down Expand Up @@ -48,3 +50,65 @@ def on_epoch_end(self, *args, **kwargs) -> None:
assert dataset["train"].num_rows == 48
assert dataset["test"].num_rows == 12
self.assertEqual(tracker.call.call_count, 2)


class FailsAfterTests(unittest.TestCase):
    """Exercises the consecutive-identical-failure guard of TransformersNER."""

    # one fewer than the threshold, i.e. how many identical failures
    # must still be tolerated without an explicit raise
    SHOULD_WORK_FOR = RAISE_AFTER_CONSECUTIVE_IDENTICAL_FAILURES - 1

    @classmethod
    def setUpClass(cls) -> None:
        config = Config()
        config.general["spacy_model"] = "en_core_web_md"
        maker = CDBMaker(config)
        csv_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.csv")
        cdb = maker.prepare_csvs([csv_path], full_build=True)
        Doc.set_extension("ents", default=[], force=True)
        # register every Span extension the NER pipeline writes to
        span_defaults = {
            "confidence": -1,
            "id": 0,
            "detected_name": None,
            "link_candidates": None,
            "cui": -1,
            "context_similarity": -1,
        }
        for ext_name, default in span_defaults.items():
            Span.set_extension(ext_name, default=default, force=True)
        cls.undertest = TransformersNER(cdb)
        cls.undertest.create_eval_pipeline()

    def setUp(self) -> None:
        # recreating the pipeline resets the consecutive-failure counter
        self.undertest.create_eval_pipeline()
        text = ("\nPatient Name: John Smith\nAddress: 15 Maple Avenue"
                "\nCity: New York\nCC: Chronic back pain\n\nHX: Mr. Smith")
        self.spacy_doc = English().make_doc(text)

    def test_runs_correctly_normal(self):
        result = self.undertest(self.spacy_doc)
        self.assertIs(result, self.spacy_doc)

    def _bork_pipe(self):
        # removing the pipe makes every subsequent call fail identically
        self.undertest.ner_pipe = None

    def _bork_pipe2(self):
        # replace the pipe's call with one raising a different exception type
        def _fake_call(*args, **kwargs) -> Doc:
            raise ValueError()
        self.undertest.ner_pipe.__call__ = _fake_call

    def test_runs_when_borked(self):
        self._bork_pipe()
        # failures below the threshold are caught and logged, not raised
        for attempt in range(1, self.SHOULD_WORK_FOR + 1):
            result = self.undertest(self.spacy_doc)
            self.assertIs(result, self.spacy_doc)
            self.assertEqual(self.undertest._consecutive_identical_failures, attempt)

    def test_no_fail_if_different_exceptions(self):
        self._bork_pipe2()
        self.undertest(self.spacy_doc)
        # this shouldn't raise an exception since the failure is different
        # even though the total number is high enough
        self.test_runs_when_borked()

    def test_raises_upon_consecutive_fails(self):
        self._bork_pipe()
        # runs the ones that should work anyway (but not 1 more!)
        self.test_runs_when_borked()
        self.assertEqual(self.undertest._consecutive_identical_failures, self.SHOULD_WORK_FOR)
        with self.assertRaises(TooManyConsecutiveFailuresException):
            self.undertest(self.spacy_doc)
Loading