diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py
index 32eb2352..1cc4c5f2 100644
--- a/medcat/ner/transformers_ner.py
+++ b/medcat/ner/transformers_ner.py
@@ -4,7 +4,7 @@
 import datasets
 from spacy.tokens import Doc
 from datetime import datetime
-from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable
+from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
 from spacy.tokens import Span
 import inspect
 from functools import partial
@@ -87,7 +87,13 @@ def create_eval_pipeline(self):
         # NOTE: this will fix the DeID model(s) created before medcat 1.9.3
         # though this fix may very well be unstable
         self.ner_pipe.tokenizer._in_target_context_manager = False
+        # if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
+        #     # NOTE: this will fix the DeID model(s) created with transformers before 4.42
+        #     # and allow them to run with later transformers
+        #     self.ner_pipe.tokenizer.split_special_tokens = False
         self.ner_pipe.device = self.model.device
+        self._consecutive_identical_failures = 0
+        self._last_exception: Optional[Tuple[str, Type[Exception]]] = None
 
     def get_hash(self) -> str:
         """A partial hash trying to catch differences between models.
@@ -390,34 +396,33 @@ def _process(self,
         #all_text_processed = self.tokenizer.encode_eval(all_text)
         # For now we will process the documents one by one, should be improved in the future to use batching
         for doc in docs:
-            try:
-                res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
-                doc.ents = []  # type: ignore
-                for r in res:
-                    inds = []
-                    for ind, word in enumerate(doc):
-                        end_char = word.idx + len(word.text)
-                        if end_char <= r['end'] and end_char > r['start']:
-                            inds.append(ind)
-                        # To not loop through everything
-                        if end_char > r['end']:
-                            break
-                    if inds:
-                        entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
-                        entity._.cui = r['entity_group']
-                        entity._.context_similarity = r['score']
-                        entity._.detected_name = r['word']
-                        entity._.id = len(doc._.ents)
-                        entity._.confidence = r['score']
-
-                        doc._.ents.append(entity)
-                create_main_ann(self.cdb, doc)
-                if self.cdb.config.general['make_pretty_labels'] is not None:
-                    make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
-                if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
-                    map_ents_to_groups(self.cdb, doc)
-            except Exception as e:
-                logger.warning(e, exc_info=True)
+            res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy'])
+            doc.ents = []  # type: ignore
+            for r in res:
+                inds = []
+                for ind, word in enumerate(doc):
+                    end_char = word.idx + len(word.text)
+                    if end_char <= r['end'] and end_char > r['start']:
+                        inds.append(ind)
+                    # To not loop through everything
+                    if end_char > r['end']:
+                        break
+                if inds:
+                    entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group'])
+                    entity._.cui = r['entity_group']
+                    entity._.context_similarity = r['score']
+                    entity._.detected_name = r['word']
+                    entity._.id = len(doc._.ents)
+                    entity._.confidence = r['score']
+
+                    doc._.ents.append(entity)
+            create_main_ann(self.cdb, doc)
+            if self.cdb.config.general['make_pretty_labels'] is not None:
+                make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']])
+            if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}):
+                map_ents_to_groups(self.cdb, doc)
+            self._consecutive_identical_failures = 0  # success
+            self._last_exception = None
         yield from docs
 
     # Override
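
The new `_consecutive_identical_failures` and `_last_exception` attributes are initialized in `create_eval_pipeline` and reset on each successful document, but no consuming code appears in this hunk. Below is a minimal sketch of how a caller might use them to tell a systematically broken model apart from a single bad document, now that `_process` lets exceptions propagate. The `safe_process` wrapper and `MAX_IDENTICAL_FAILURES` threshold are hypothetical names for illustration, not part of this patch:

```python
import logging
from typing import Iterable, Iterator

from spacy.tokens import Doc

logger = logging.getLogger(__name__)

# Hypothetical threshold; not part of this patch.
MAX_IDENTICAL_FAILURES = 10


def safe_process(ner, docs: Iterable[Doc]) -> Iterator[Doc]:
    """Sketch of a caller-side wrapper around TransformersNER._process.

    Since _process no longer swallows exceptions, a wrapper can count how
    often the *same* exception repeats and escalate once the failure looks
    systematic (e.g. a broken tokenizer) rather than document-specific.
    """
    try:
        yield from ner._process(docs)
    except Exception as exc:
        # Key shape matches the Optional[Tuple[str, Type[Exception]]] annotation.
        key = (str(exc), type(exc))
        if ner._last_exception == key:
            ner._consecutive_identical_failures += 1
        else:
            ner._consecutive_identical_failures = 1
            ner._last_exception = key
        if ner._consecutive_identical_failures >= MAX_IDENTICAL_FAILURES:
            # The same exception keeps recurring: re-raise instead of
            # logging it silently forever.
            raise
        logger.warning(exc, exc_info=True)
```

Because `_process` sets `_consecutive_identical_failures = 0` and `_last_exception = None` only after a document completes successfully, the counter in a wrapper like this grows only across an uninterrupted run of identical failures.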