From e0c64561eb7c839f42e5a32188789981d6c3e6d3 Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Thu, 21 Sep 2023 16:21:00 +0300 Subject: [PATCH] CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method (#345) * CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method * CU-8692mevx8 Fix filter retention in train_supervised method --- medcat/cat.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/medcat/cat.py b/medcat/cat.py index 5218e9d02..2323cd737 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -490,7 +490,8 @@ def _print_stats(self, fp_docs: Set = set() fn_docs: Set = set() - local_filters = self.config.linking.filters.copy_of() + orig_filters = self.config.linking.filters.copy_of() + local_filters = self.config.linking.filters for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): local_filters.cuis = set() @@ -645,6 +646,8 @@ def _print_stats(self, except Exception: traceback.print_exc() + self.config.linking.filters = orig_filters + return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples def _set_project_filters(self, local_filters: LinkingFilters, project: dict, @@ -1033,7 +1036,13 @@ def train_supervised_raw(self, """ checkpoint = self._init_ckpts(is_resumed, checkpoint) - local_filters = self.config.linking.filters.copy_of() + # the config.linking.filters stuff is used directly in + # medcat.linking.context_based_linker and medcat.linking.vector_context_model + # as such, they need to be kept up to date with per-project filters + # However, the original state needs to be kept track of + # so that it can be restored after training + orig_filters = self.config.linking.filters.copy_of() + local_filters = self.config.linking.filters fp = fn = tp = p = r = f1 = examples = {} @@ -1094,7 +1103,7 @@ def train_supervised_raw(self, if retain_filters and extra_cui_filter and not retain_extra_cui_filter: # adding project filters without extra_cui_filters self._set_project_filters(local_filters, project, set(), use_filters) - self.config.linking.filters.merge_with(local_filters) + orig_filters.merge_with(local_filters) # adding extra_cui_filters, but NOT project filters self._set_project_filters(local_filters, project, extra_cui_filter, False) # refrain from doing it again for subsequent epochs @@ -1140,7 +1149,7 @@ def train_supervised_raw(self, checkpoint.save(self.cdb, latest_trained_step) # if retaining MCT filters AND (if they exist) extra_cui_filters if retain_filters: - self.config.linking.filters.merge_with(local_filters) + orig_filters.merge_with(local_filters) # refrain from doing it again for subsequent epochs retain_filters = False @@ -1162,6 +1171,9 @@ def train_supervised_raw(self, use_groups=use_groups, extra_cui_filter=extra_cui_filter) + # reset the state of filters + self.config.linking.filters = orig_filters + return fp, fn, tp, p, r, f1, cui_counts, examples def get_entities(self,