From 3d142496e65c152252d4bf8543bb48869b326cdb Mon Sep 17 00:00:00 2001 From: kapoorlab Date: Sat, 26 Aug 2023 16:56:56 +0200 Subject: [PATCH] event threshold as list --- src/oneat/NEATModels/neat_densevollnet.py | 2 +- src/oneat/NEATUtils/HolovizNapari.py | 14 ++++++++++---- src/oneat/_version.py | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/oneat/NEATModels/neat_densevollnet.py b/src/oneat/NEATModels/neat_densevollnet.py index 8a92a59..be12848 100644 --- a/src/oneat/NEATModels/neat_densevollnet.py +++ b/src/oneat/NEATModels/neat_densevollnet.py @@ -15,7 +15,7 @@ from tensorflow.keras.models import load_model from tifffile import imread from sklearn.utils import class_weight - +from typing import Union, List class NEATDenseVollNet(object): """ Parameters diff --git a/src/oneat/NEATUtils/HolovizNapari.py b/src/oneat/NEATUtils/HolovizNapari.py index 8446fa6..91f410c 100644 --- a/src/oneat/NEATUtils/HolovizNapari.py +++ b/src/oneat/NEATUtils/HolovizNapari.py @@ -327,12 +327,15 @@ def cluster_spheres(event_locations_dict, event_locations_size_dict, nms_space, def headlesscall( key_categories: dict, - event_threshold: float, + event_threshold: list, nms_space: int, nms_time: int, csvdir: str, savedir: str, ): + if isinstance(event_threshold, float): + event_threshold = [event_threshold] * len(key_categories) + for (event_name, event_label) in key_categories.items(): if event_label > 0: event_locations = [] @@ -377,7 +380,7 @@ def headlesscall( size = float(listsize[i]) score = float(listscore[i]) confidence = listconfidence[i] - if score > event_threshold: + if score > event_threshold[event_label]: event_locations.append( [int(tcenter), int(ycenter), int(xcenter)] ) @@ -470,12 +473,15 @@ def headlesscall( def headlessvolumecall( key_categories: dict, - event_threshold: float, + event_threshold: list, nms_space: int, nms_time: int, csvdir: str, savedir: str, ): + if isinstance(event_threshold, float): + event_threshold = [event_threshold] * len(key_categories) + for (event_name, event_label) in key_categories.items(): if event_label > 0: event_locations = [] @@ -523,7 +529,7 @@ def headlessvolumecall( size = listsize[i] score = listscore[i] confidence = listconfidence[i] - if score > event_threshold: + if score > event_threshold[event_label]: event_locations.append( [int(tcenter), int(zcenter), int(ycenter), int(xcenter)] ) diff --git a/src/oneat/_version.py b/src/oneat/_version.py index 7433207..a4ef464 100644 --- a/src/oneat/_version.py +++ b/src/oneat/_version.py @@ -1,4 +1,4 @@ # file generated by setuptools_scm # don't change, don't track in version control -__version__ = version = '6.6.1' -__version_tuple__ = version_tuple = (6, 6, 1) +__version__ = version = '6.6.9' +__version_tuple__ = version_tuple = (6, 6, 9)