Skip to content

Commit

Permalink
CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method (#345)
Browse files Browse the repository at this point in the history

* CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method

* CU-8692mevx8 Fix filter retention in train_supervised method
  • Loading branch information
mart-r authored Sep 21, 2023
1 parent 3aaef44 commit e0c6456
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,8 @@ def _print_stats(self,
fp_docs: Set = set()
fn_docs: Set = set()

local_filters = self.config.linking.filters.copy_of()
orig_filters = self.config.linking.filters.copy_of()
local_filters = self.config.linking.filters
for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False):
local_filters.cuis = set()

Expand Down Expand Up @@ -645,6 +646,8 @@ def _print_stats(self,
except Exception:
traceback.print_exc()

self.config.linking.filters = orig_filters

return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples

def _set_project_filters(self, local_filters: LinkingFilters, project: dict,
Expand Down Expand Up @@ -1033,7 +1036,13 @@ def train_supervised_raw(self,
"""
checkpoint = self._init_ckpts(is_resumed, checkpoint)

local_filters = self.config.linking.filters.copy_of()
# the config.linking.filters stuff is used directly in
# medcat.linking.context_based_linker and medcat.linking.vector_context_model
# as such, they need to be kept up to date with per-project filters
# However, the original state needs to be kept track of
# so that it can be restored after training
orig_filters = self.config.linking.filters.copy_of()
local_filters = self.config.linking.filters

fp = fn = tp = p = r = f1 = examples = {}

Expand Down Expand Up @@ -1094,7 +1103,7 @@ def train_supervised_raw(self,
if retain_filters and extra_cui_filter and not retain_extra_cui_filter:
# adding project filters without extra_cui_filters
self._set_project_filters(local_filters, project, set(), use_filters)
self.config.linking.filters.merge_with(local_filters)
orig_filters.merge_with(local_filters)
# adding extra_cui_filters, but NOT project filters
self._set_project_filters(local_filters, project, extra_cui_filter, False)
# refrain from doing it again for subsequent epochs
Expand Down Expand Up @@ -1140,7 +1149,7 @@ def train_supervised_raw(self,
checkpoint.save(self.cdb, latest_trained_step)
# if retaining MCT filters AND (if they exist) extra_cui_filters
if retain_filters:
self.config.linking.filters.merge_with(local_filters)
orig_filters.merge_with(local_filters)
# refrain from doing it again for subsequent epochs
retain_filters = False

Expand All @@ -1162,6 +1171,9 @@ def train_supervised_raw(self,
use_groups=use_groups,
extra_cui_filter=extra_cui_filter)

# reset the state of filters
self.config.linking.filters = orig_filters

return fp, fn, tp, p, r, f1, cui_counts, examples

def get_entities(self,
Expand Down

0 comments on commit e0c6456

Please sign in to comment.