From 3b2c328cc3b54f8bb3a9775099d9b7c40c425822 Mon Sep 17 00:00:00 2001 From: jerry genser Date: Mon, 22 Jan 2024 11:42:34 -0500 Subject: [PATCH 1/4] add: ability to predict on other spangroups --- medcat/config_meta_cat.py | 1 + medcat/meta_cat.py | 24 ++++++++++++++---------- tests/test_meta_cat.py | 32 ++++++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index ae3e82ef8..1731cf610 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -37,6 +37,7 @@ class General(MixingConfig, BaseModel): a deployment.""" pipe_batch_size_in_chars: int = 20000000 """How many characters are piped at once into the meta_cat class""" + span_group: Optional[str] = None class Config: extra = Extra.allow diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index d92e6ea61..7f9615b56 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -5,7 +5,7 @@ import numpy from multiprocessing import Lock from torch import nn, Tensor -from spacy.tokens import Doc +from spacy.tokens import Doc, Span from datetime import datetime from typing import Iterable, Iterator, Optional, Dict, List, Tuple, cast, Union from medcat.utils.hasher import Hasher @@ -356,6 +356,17 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device)) return meta_cat + + def get_ents(self, doc: Doc) -> List[Span]: + span_group_name = self.config.general.span_group + if span_group_name: + return doc.spans[span_group_name] + + # Should we annotate overlapping entities + if self.config.general['annotate_overlapping']: + return doc._.ents + + return doc.ents def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowercase: bool) -> Tuple: """Prepares document. 
@@ -381,11 +392,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe cntx_right = config.general['cntx_right'] replace_center = config.general['replace_center'] - # Should we annotate overlapping entities - if config.general['annotate_overlapping']: - ents = doc._.ents - else: - ents = doc.ents + ents = self.get_ents(doc) samples = [] last_ind = 0 @@ -522,10 +529,7 @@ def _set_meta_anns(self, predictions = all_predictions[start_ind:end_ind] confidences = all_confidences[start_ind:end_ind] - if config.general['annotate_overlapping']: - ents = doc._.ents - else: - ents = doc.ents + ents = self.get_ents(doc) for ent in ents: ent_ind = ent_id2ind[ent._.id] diff --git a/tests/test_meta_cat.py b/tests/test_meta_cat.py index df5be9f77..ac332a81a 100644 --- a/tests/test_meta_cat.py +++ b/tests/test_meta_cat.py @@ -7,7 +7,8 @@ from medcat.meta_cat import MetaCAT from medcat.config_meta_cat import ConfigMetaCAT from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT - +import spacy +from spacy.tokens import Span class MetaCATTests(unittest.TestCase): @@ -19,7 +20,7 @@ def setUpClass(cls) -> None: config.train['nepochs'] = 1 config.model['input_size'] = 100 - cls.meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) + cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") os.makedirs(cls.tmp_dir, exist_ok=True) @@ -44,6 +45,33 @@ def test_save_load(self): self.assertEqual(f1, n_f1) + def test_predict_spangroup(self): + Span.set_extension('id', default=0, force=True) + Span.set_extension('meta_anns', default=None, force=True) + + + json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json') + self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) + self.meta_cat.save(self.tmp_dir) + n_meta_cat = MetaCAT.load(self.tmp_dir) + assert 
n_meta_cat.config.general.span_group is None + + spangroup_name = 'predict_spangroup' + n_meta_cat.config.general.span_group = spangroup_name + nlp = spacy.blank("en") + doc = nlp("No history of diabetes.") + span = doc.char_span(14, 22, label="foo_spantype") + assert span.text == 'diabetes' + doc.spans[spangroup_name] = [span] + doc = n_meta_cat(doc) + + # set back to None + n_meta_cat.config.general.span_group = None + assert doc.spans[spangroup_name][0]._.meta_anns['Status']['value'] == 'Affirmed' + + + + if __name__ == '__main__': unittest.main() From ed7d653fa099ceb6480f77aa0c5e82d631564c7f Mon Sep 17 00:00:00 2001 From: jerry genser Date: Tue, 23 Jan 2024 08:57:54 -0500 Subject: [PATCH 2/4] add: pr comments and better error --- medcat/config_meta_cat.py | 2 ++ medcat/meta_cat.py | 11 +++++++---- tests/test_meta_cat.py | 41 +++++++++++++++++++++++++++------------ 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index 1731cf610..2aacbc33b 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -38,6 +38,8 @@ class General(MixingConfig, BaseModel): pipe_batch_size_in_chars: int = 20000000 """How many characters are piped at once into the meta_cat class""" span_group: Optional[str] = None + """If set, the spacy span group that the metacat model will assign annotaitons. 
+ Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings""" class Config: extra = Extra.allow diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 7f9615b56..21bdced7a 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -357,10 +357,13 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA return meta_cat - def get_ents(self, doc: Doc) -> List[Span]: - span_group_name = self.config.general.span_group - if span_group_name: - return doc.spans[span_group_name] + def get_ents(self, doc: Doc) -> Iterable[Span]: + spangroup_name = self.config.general.span_group + if spangroup_name: + try: + return doc.spans[spangroup_name] + except KeyError: + raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.") # Should we annotate overlapping entities if self.config.general['annotate_overlapping']: diff --git a/tests/test_meta_cat.py b/tests/test_meta_cat.py index ac332a81a..8cd444668 100644 --- a/tests/test_meta_cat.py +++ b/tests/test_meta_cat.py @@ -45,32 +45,49 @@ def test_save_load(self): self.assertEqual(f1, n_f1) - def test_predict_spangroup(self): + def _prepare_doc_w_spangroup(self, spangroup_name: str): + """ + Create spans under an arbitrary spangroup key + """ Span.set_extension('id', default=0, force=True) Span.set_extension('meta_anns', default=None, force=True) + nlp = spacy.blank("en") + doc = nlp("Pt has diabetes and copd.") + span_0 = doc.char_span(7,15, label="diabetes") + assert span_0.text == 'diabetes' + span_1 = doc.char_span(20,24, label="copd") + assert span_1.text == 'copd' + doc.spans[spangroup_name] = [span_0, span_1] + return doc + def test_predict_spangroup(self): json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json') self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) self.meta_cat.save(self.tmp_dir) n_meta_cat = 
MetaCAT.load(self.tmp_dir) - assert n_meta_cat.config.general.span_group is None - spangroup_name = 'predict_spangroup' + spangroup_name = "mock_span_group" n_meta_cat.config.general.span_group = spangroup_name - nlp = spacy.blank("en") - doc = nlp("No history of diabetes.") - span = doc.char_span(14, 22, label="foo_spantype") - assert span.text == 'diabetes' - doc.spans[spangroup_name] = [span] - doc = n_meta_cat(doc) - # set back to None - n_meta_cat.config.general.span_group = None - assert doc.spans[spangroup_name][0]._.meta_anns['Status']['value'] == 'Affirmed' + doc = self._prepare_doc_w_spangroup(spangroup_name) + doc = n_meta_cat(doc) + spans = doc.spans[spangroup_name] + self.assertEqual(len(spans), 2) + # All spans are annotated + for span in spans: + self.assertEqual(span._.meta_anns['Status']['value'], "Affirmed") + # Informative error if spangroup is not set + doc = self._prepare_doc_w_spangroup("foo") + n_meta_cat.config.general.span_group = "bar" + try: + doc = n_meta_cat(doc) + except Exception as error: + self.assertIn("Configuration error", str(error)) + n_meta_cat.config.general.span_group = None if __name__ == '__main__': From 5706016414baeb527a2fded980d0d75a77f5c575 Mon Sep 17 00:00:00 2001 From: jerry genser Date: Tue, 23 Jan 2024 09:00:07 -0500 Subject: [PATCH 3/4] fix: typo --- medcat/config_meta_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index 2aacbc33b..47f42dc28 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -38,7 +38,7 @@ class General(MixingConfig, BaseModel): pipe_batch_size_in_chars: int = 20000000 """How many characters are piped at once into the meta_cat class""" span_group: Optional[str] = None - """If set, the spacy span group that the metacat model will assign annotaitons. + """If set, the spacy span group that the metacat model will assign annotations. 
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings""" class Config: From fb51620b62e536d97f99ba2c190b27374df8541e Mon Sep 17 00:00:00 2001 From: jerry genser Date: Mon, 29 Jan 2024 15:06:07 -0500 Subject: [PATCH 4/4] fix: linting --- medcat/meta_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 21bdced7a..bf7f09709 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -356,7 +356,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device)) return meta_cat - + def get_ents(self, doc: Doc) -> Iterable[Span]: spangroup_name = self.config.general.span_group if spangroup_name: