Skip to content

Commit

Permalink
event threshold as list
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Aug 26, 2023
1 parent 5736469 commit 3d14249
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/oneat/NEATModels/neat_densevollnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tensorflow.keras.models import load_model
from tifffile import imread
from sklearn.utils import class_weight

from typing import Union, List
class NEATDenseVollNet(object):
"""
Parameters
Expand Down
14 changes: 10 additions & 4 deletions src/oneat/NEATUtils/HolovizNapari.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,15 @@ def cluster_spheres(event_locations_dict, event_locations_size_dict, nms_space,

def headlesscall(
key_categories: dict,
event_threshold: float,
event_threshold: list,
nms_space: int,
nms_time: int,
csvdir: str,
savedir: str,
):
if isinstance(event_threshold, float):
event_threshold = [event_threshold] * len(key_categories)

for (event_name, event_label) in key_categories.items():
if event_label > 0:
event_locations = []
Expand Down Expand Up @@ -377,7 +380,7 @@ def headlesscall(
size = float(listsize[i])
score = float(listscore[i])
confidence = listconfidence[i]
if score > event_threshold:
if score > event_threshold[event_label]:
event_locations.append(
[int(tcenter), int(ycenter), int(xcenter)]
)
Expand Down Expand Up @@ -470,12 +473,15 @@ def headlesscall(

def headlessvolumecall(
key_categories: dict,
event_threshold: float,
event_threshold: list,
nms_space: int,
nms_time: int,
csvdir: str,
savedir: str,
):
if isinstance(event_threshold, float):
event_threshold = [event_threshold] * len(key_categories)

for (event_name, event_label) in key_categories.items():
if event_label > 0:
event_locations = []
Expand Down Expand Up @@ -523,7 +529,7 @@ def headlessvolumecall(
size = listsize[i]
score = listscore[i]
confidence = listconfidence[i]
if score > event_threshold:
if score > event_threshold[event_label]:
event_locations.append(
[int(tcenter), int(zcenter), int(ycenter), int(xcenter)]
)
Expand Down
4 changes: 2 additions & 2 deletions src/oneat/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# file generated by setuptools_scm
# don't change, don't track in version control
__version__ = version = '6.6.1'
__version_tuple__ = version_tuple = (6, 6, 1)
__version__ = version = '6.6.9'
__version_tuple__ = version_tuple = (6, 6, 9)

0 comments on commit 3d14249

Please sign in to comment.