CU-2e77a31: Fix a bunch of typing issues
mart-r committed Oct 30, 2023
1 parent dd13abb commit c6579be
Showing 13 changed files with 52 additions and 42 deletions.
8 changes: 8 additions & 0 deletions medcat/base.py
@@ -347,3 +347,11 @@ def multiprocessing_pipe(self,
Union[List[Tuple], Dict]:
{id: doc_json, id: doc_json, ...} or if return_dict is False, a list of tuples: [(id, doc_json), (id, doc_json), ...]
"""

+ @staticmethod
+ def _get_doc_annotations(doc: dict):
+ if type(doc['annotations']) == list: # type: ignore
+ return doc['annotations'] # type: ignore
+ if type(doc['annotations']) == dict: # type: ignore
+ return doc['annotations'].values() # type: ignore
+ return None
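
Note: the helper moved here (out of cat.py, see below) normalises annotations from a MedCATtrainer-style export, which may store them either as a list or as a dict keyed by id. A minimal usage sketch with hypothetical data, where `cat` stands for an already-constructed CAT instance:

    # Hypothetical export documents; only the 'annotations' key matters here.
    doc_with_list = {'annotations': [{'cui': 'C0011849', 'start': 0, 'end': 8}]}
    doc_with_dict = {'annotations': {'1': {'cui': 'C0011849', 'start': 0, 'end': 8}}}

    cat._get_doc_annotations(doc_with_list)          # returns the list as-is
    cat._get_doc_annotations(doc_with_dict)          # returns the dict's .values()
    cat._get_doc_annotations({'annotations': None})  # anything else falls through to None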
8 changes: 0 additions & 8 deletions medcat/cat.py
@@ -1553,13 +1553,5 @@ def _pipe_error_handler(proc_name: str, proc: "Pipe", docs: List[Doc], e: Except
if hasattr(doc, "text"):
logger.warning("%s...", doc.text[:50])

- @staticmethod
- def _get_doc_annotations(doc: Doc):
- if type(doc['annotations']) == list: # type: ignore
- return doc['annotations'] # type: ignore
- if type(doc['annotations']) == dict: # type: ignore
- return doc['annotations'].values() # type: ignore
- return None

def destroy_pipe(self):
self.pipe.destroy()
2 changes: 1 addition & 1 deletion medcat/cdb.py
@@ -485,7 +485,7 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[

return cdb

- def import_training(self, cdb: "CDB", overwrite: bool = True) -> None:
+ def import_training(self, cdb: CDBBase, overwrite: bool = True) -> None:
"""This will import vector embeddings from another CDB. No new concepts will be added.
IMPORTANT it will not import name maps (cui2names, name2cuis or anything else) only vectors.
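
A short usage sketch of the loosened signature (both objects are hypothetical, already-loaded concept databases):

    # Only context vectors are copied; name maps on the target stay untouched.
    target_cdb.import_training(cdb=source_cdb, overwrite=True)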
10 changes: 9 additions & 1 deletion medcat/cdbbase.py
@@ -15,7 +15,7 @@ def __init__(self) -> None:
self.snames: Set
self.cui2names: Dict[str, Set[str]]
self.cui2snames: Dict[str, Set[str]]
- self.cui2context_vectors: Dict[str, Dict[str, np.array]]
+ self.cui2context_vectors: Dict[str, Dict[str, np.ndarray]]
self.cui2count_train: Dict[str, int]
self.cui2info: Dict
self.cui2tags: Dict[str, List[str]]
@@ -286,3 +286,11 @@ def get_hash(self, force_recalc: bool = False) -> str:
Returns:
str: The hash for this CDB.
"""

+ @abstractmethod
+ def make_stats(self) -> dict:
+ pass
+
+ @abstractmethod
+ def update_cui2average_confidence(self, cui: str, new_sim: float) -> None:
+ pass
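
The two methods promoted to the abstract interface are what stats.py and cat.py call on a concept database. A deliberately minimal, hypothetical subclass (not part of MedCAT) showing just that surface; a real implementation must also provide the remaining abstract members such as get_hash:

    from typing import Dict
    from medcat.cdbbase import CDBBase

    class ToyCDB(CDBBase):
        def __init__(self) -> None:
            self.cui2count_train: Dict[str, int] = {}
            self.cui2average_confidence: Dict[str, float] = {}

        def make_stats(self) -> dict:
            # A toy summary; the real CDB reports counts, vectors, etc.
            return {'Number of concepts': len(self.cui2count_train)}

        def update_cui2average_confidence(self, cui: str, new_sim: float) -> None:
            # Plausible running-average update; the real formula lives in medcat.cdb.CDB.
            cnt = self.cui2count_train.get(cui, 0)
            avg = self.cui2average_confidence.get(cui, 0.0)
            self.cui2average_confidence[cui] = (avg * cnt + new_sim) / (cnt + 1)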
4 changes: 2 additions & 2 deletions medcat/linking/context_based_linker.py
@@ -5,7 +5,7 @@
from typing import Dict
from medcat.linking.vector_context_model import ContextModel
from medcat.pipeline.pipe_runner import PipeRunner
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from medcat.vocab import Vocab
from medcat.config import Config
from medcat.utils.postprocessing import map_ents_to_groups, make_pretty_labels, create_main_ann, LabelStyle
@@ -27,7 +27,7 @@ class Linker(PipeRunner):
name = 'cat_linker'

# Override
- def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None:
+ def __init__(self, cdb: CDBBase, vocab: Vocab, config: Config) -> None:
self.cdb = cdb
self.vocab = vocab
self.config = config
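
The same CDB -> CDBBase loosening is repeated for the NER, postprocessing, regression and utility modules below. The practical effect, sketched with illustrative names, is that these components now type-check against the abstract interface rather than the concrete class:

    from medcat.cdbbase import CDBBase
    from medcat.config import Config
    from medcat.vocab import Vocab
    from medcat.linking.context_based_linker import Linker

    def build_linker(cdb: CDBBase, vocab: Vocab, config: Config) -> Linker:
        # Any concrete CDBBase (the standard medcat.cdb.CDB included) is accepted.
        return Linker(cdb, vocab, config)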
4 changes: 2 additions & 2 deletions medcat/linking/vector_context_model.py
@@ -3,7 +3,7 @@
from typing import Tuple, Dict, List, Union
from spacy.tokens import Span, Doc
from medcat.utils.matutils import unitvec
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from medcat.vocab import Vocab
from medcat.config import Config
import random
@@ -21,7 +21,7 @@ class ContextModel(object):
config (Config): The config to be used
"""

- def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None:
+ def __init__(self, cdb: CDBBase, vocab: Vocab, config: Config) -> None:
self.cdb = cdb
self.vocab = vocab
self.config = config
4 changes: 2 additions & 2 deletions medcat/ner/vocab_based_annotator.py
@@ -4,13 +4,13 @@
import logging
from typing import List, Optional
from spacy.tokens import Span, Token, Doc
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from medcat.config import Config

logger = logging.getLogger(__name__)


- def maybe_annotate_name(name: str, tkns: List[Token], doc: Doc, cdb: CDB, config: Config, label: str = 'concept') -> Optional[Span]:
+ def maybe_annotate_name(name: str, tkns: List[Token], doc: Doc, cdb: CDBBase, config: Config, label: str = 'concept') -> Optional[Span]:
"""Given a name it will check should it be annotated based on config rules. If yes
the annotation will be added to the doc._.ents array.
4 changes: 2 additions & 2 deletions medcat/ner/vocab_based_ner.py
@@ -2,7 +2,7 @@
from spacy.tokens import Doc
from medcat.ner.vocab_based_annotator import maybe_annotate_name
from medcat.pipeline.pipe_runner import PipeRunner
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from medcat.config import Config


@@ -15,7 +15,7 @@ class NER(PipeRunner):
name = 'cat_ner'

# Override
- def __init__(self, cdb: CDB, config: Config) -> None:
+ def __init__(self, cdb: CDBBase, config: Config) -> None:
self.config = config
self.cdb = cdb
super().__init__(self.config.general.workers)
27 changes: 14 additions & 13 deletions medcat/stats/stats.py
@@ -1,4 +1,4 @@
- from typing import Dict, Optional, Set, Tuple, Callable, List
+ from typing import Dict, Optional, Set, Tuple, Callable, List, cast

from tqdm import tqdm
import traceback
@@ -16,11 +16,11 @@ class StatsBuilder:
def __init__(self,
filters: LinkingFilters,
addl_info: dict,
- doc_getter: Callable[[str], Doc],
- doc_annotation_getter: Callable[[Doc], list],
+ doc_getter: Callable[[Optional[str], bool], Optional[Doc]],
+ doc_annotation_getter: Callable[[dict], list],
cui2group: Dict[str, str],
cui2preferred_name: Dict[str, str],
- cui2names: Dict[str, List[str]],
+ cui2names: Dict[str, Set[str]],
use_project_filters: bool = False,
use_overlaps: bool = False,
use_cui_doc_limit: bool = False,
@@ -68,9 +68,10 @@ def process_project(self, project: dict) -> None:
total=len(documents),
leave=False,
):
- self.process_document(project.get('name'), project.get('id'), doc)
+ self.process_document(cast(str, project.get('name')),
+ cast(str, project.get('id')), doc)

- def process_document(self, project_name: str, project_id: str, doc: Doc) -> None:
+ def process_document(self, project_name: str, project_id: str, doc: dict) -> None:
anns = self._get_doc_annotations(doc)

# Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still
@@ -97,7 +98,7 @@ def process_document(self, project_name: str, project_id: str, doc: Doc) -> None
p_anns_norm, p_anns_examples)
self._process_anns_norm(doc, anns_norm, p_anns_norm, anns_examples)

- def _process_anns_norm(self, doc: Doc, anns_norm: list, p_anns_norm: list,
+ def _process_anns_norm(self, doc: dict, anns_norm: list, p_anns_norm: list,
anns_examples: list) -> None:
for iann, ann in enumerate(anns_norm):
if ann not in p_anns_norm:
@@ -108,7 +109,7 @@ def _process_anns_norm(self, doc: Doc, anns_norm: list, p_anns_norm: list,
self.fns[cui] = self.fns.get(cui, 0) + 1
self.examples['fn'][cui] = self.examples['fn'].get(cui, []) + [anns_examples[iann]]

- def _process_p_anns(self, project_name: str, project_id: str, doc: Doc, p_anns: list) -> Tuple[list, list]:
+ def _process_p_anns(self, project_name: str, project_id: str, doc: dict, p_anns: list) -> Tuple[list, list]:
p_anns_norm = []
p_anns_examples = []
for ann in p_anns:
@@ -120,7 +121,7 @@ def _process_p_anns(self, project_name: str, project_id: str, doc: Doc, p_anns:
p_anns_examples.append(self._create_annoation_2(project_name, project_id, cui, doc, ann))
return p_anns_norm, p_anns_examples

- def _count_p_anns_norm(self, doc: Doc, anns_norm: list, anns_norm_neg: list,
+ def _count_p_anns_norm(self, doc: dict, anns_norm: list, anns_norm_neg: list,
p_anns_norm: list, p_anns_examples: list) -> None:
for iann, ann in enumerate(p_anns_norm):
cui = ann[1]
@@ -143,7 +144,7 @@ def _count_p_anns_norm(self, doc: Doc, anns_norm: list, anns_norm_neg: list,

self.examples['fp'][cui] = self.examples['fp'].get(cui, []) + [example]

- def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: Doc, ann: Dict) -> Dict:
+ def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: dict, ann: Dict) -> Dict:
return {"text": doc['text'][max(0, ann['start']-60):ann['end']+60],
"cui": cui,
"start": ann['start'],
@@ -155,7 +156,7 @@ def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: D
"project id": project_id,
"document id": doc.get('id')}

- def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc: Doc, ann) -> Dict:
+ def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc: dict, ann) -> Dict:
return {"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60],
"cui": cui,
"start": ann.start_char,
@@ -168,7 +169,7 @@ def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc:
"document id": doc.get('id')}

def _preprocess_annotations(self, project_name: str, project_id: str,
- doc: Doc, anns: List[Dict]) -> Tuple[list, list, list, list]:
+ doc: dict, anns: List[Dict]) -> Tuple[list, list, list, list]:
anns_norm = []
anns_norm_neg = []
anns_examples = []
@@ -255,7 +256,7 @@ def from_cat(cls, cat: CATBase,
extra_cui_filter: Optional[Set] = None) -> 'StatsBuilder':
return StatsBuilder(filters=local_filters,
addl_info=cat.cdb.addl_info,
- doc_getter=cat,
+ doc_getter=cat.__call__,
doc_annotation_getter=cat._get_doc_annotations,
cui2group=cat.cdb.addl_info['cui2group'],
cui2preferred_name=cat.cdb.cui2preferred_name,
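
Two of the stats.py changes are easy to misread as behavioural: they are not. `cast` only informs the type checker, and `cat.__call__` is the same object `cat(...)` would invoke. A small self-contained sketch of the `cast` point (the dict is hypothetical):

    from typing import cast

    project = {'name': 'demo project', 'id': '1'}  # hypothetical trainer export entry

    # dict.get() is typed Optional, so mypy rejects passing the result where a plain
    # str is required even when the key is known to exist; cast() is a no-op at runtime.
    name = cast(str, project.get('name'))

Passing `cat.__call__` rather than `cat` serves the same purpose: it hands mypy a bound method with a concrete signature to match against the `Callable[[Optional[str], bool], Optional[Doc]]` annotation on `doc_getter`.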
3 changes: 2 additions & 1 deletion medcat/utils/data_utils.py
@@ -7,6 +7,7 @@
from spacy.tokens.doc import Doc
from spacy.tokens.span import Span
from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from collections import defaultdict
import random

@@ -791,7 +792,7 @@ def prepare_from_json_chars(data: Dict,
return out_data


- def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple:
+ def make_mc_train_test(data: Dict, cdb: CDBBase, test_size: float = 0.2) -> Tuple:
"""This is a disaster."""
cnts: Dict = {}
total_anns = 0
4 changes: 2 additions & 2 deletions medcat/utils/ner/model.py
@@ -4,7 +4,7 @@

from medcat.ner.transformers_ner import TransformersNER
from medcat.cat import CAT
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from medcat.config import Config


@@ -75,7 +75,7 @@ def config(self) -> Config:
return self.cat.config

@property
- def cdb(self) -> CDB:
+ def cdb(self) -> CDBBase:
return self.cat.cdb

@classmethod
8 changes: 4 additions & 4 deletions medcat/utils/postprocessing.py
@@ -1,6 +1,6 @@
from spacy.tokens import Span, Doc
from typing import Optional, List
- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase
from enum import Enum, auto


@@ -9,12 +9,12 @@ class LabelStyle(Enum):
long = auto()


- def map_ents_to_groups(cdb: CDB, doc: Doc) -> None:
+ def map_ents_to_groups(cdb: CDBBase, doc: Doc) -> None:
for ent in doc.ents:
ent._.cui = cdb.addl_info['cui2group'].get(ent._.cui, ent._.cui)


- def make_pretty_labels(cdb: CDB, doc: Doc, style: Optional[LabelStyle] = None) -> None:
+ def make_pretty_labels(cdb: CDBBase, doc: Doc, style: Optional[LabelStyle] = None) -> None:
ents = list(doc.ents)

n_ents = []
@@ -34,7 +34,7 @@ def make_pretty_labels(cdb: CDB, doc: Doc, style: Optional[LabelStyle] = None) -
doc.ents = n_ents # type: ignore


- def create_main_ann(cdb: CDB, doc: Doc, tuis: Optional[List] = None) -> None:
+ def create_main_ann(cdb: CDBBase, doc: Doc, tuis: Optional[List] = None) -> None:
# TODO: Separate into another piece of the pipeline
"""Creates annotation in the spacy ents list
from all the annotations for this document.
8 changes: 4 additions & 4 deletions medcat/utils/regression/targeting.py
@@ -4,7 +4,7 @@

from pydantic import BaseModel

- from medcat.cdb import CDB
+ from medcat.cdbbase import CDBBase

from medcat.utils.regression.utils import loosely_match_enum

@@ -25,12 +25,12 @@ class TranslationLayer:
Args:
cui2names (Dict[str, Set[str]]): The map from CUI to names
- name2cuis (Dict[str, Set[str]]): The map from name to CUIs
+ name2cuis (Dict[str, List[str]]): The map from name to CUIs
cui2type_ids (Dict[str, Set[str]]): The map from CUI to type_ids
cui2children (Dict[str, Set[str]]): The map from CUI to child CUIs
"""

- def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, Set[str]],
+ def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str]],
cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]]) -> None:
self.cui2names = cui2names
self.name2cuis = name2cuis
@@ -144,7 +144,7 @@ def get_parents_of(self, found_cuis: Iterable[str], cui: str, depth: int = 1) ->
return found_parents

@classmethod
- def from_CDB(cls, cdb: CDB) -> 'TranslationLayer':
+ def from_CDB(cls, cdb: CDBBase) -> 'TranslationLayer':
"""Construct a TranslationLayer object from a context database (CDB).
This translation layer will refer to the same dicts that the CDB refers to.
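
For the targeting changes, the corrected annotations mirror how the concept database actually stores its maps: names map to lists of CUIs, CUIs map to sets of names. A hedged usage sketch (the path is illustrative):

    from typing import Dict, List, Set
    from medcat.cdb import CDB
    from medcat.utils.regression.targeting import TranslationLayer

    cdb = CDB.load('/path/to/cdb.dat')   # any CDBBase-compatible database now works
    tl = TranslationLayer.from_CDB(cdb)

    name2cuis: Dict[str, List[str]] = tl.name2cuis  # lists of CUIs per name
    cui2names: Dict[str, Set[str]] = tl.cui2names   # sets of names per CUI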
