Skip to content

Commit

Permalink
thresh as tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Aug 23, 2023
1 parent 05c67be commit 5736469
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/oneat/NEATModels/neat_densevollnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def predict(self,
savename: str = '',
n_tiles : tuple = (1, 1, 1),
overlap_percent : float =0.8,
event_threshold : float = 0.5,
event_confidence : float = 0.5,
event_threshold : list = [1,0.9,0.9],
event_confidence : list = [0.5,0.5,0.5],
iou_threshold : float = 0.1,
dtype : np.dtype = np.uint8,
marker_tree : dict = None,
Expand Down Expand Up @@ -295,7 +295,16 @@ def predict(self,

self.marker_tree = marker_tree
self.remove_markers = remove_markers

if not isinstance(self.event_threshold, list ):
thresh_list = []
for (event_name,event_label) in self.key_categories.items():
thresh_list.append(self.event_threshold)
self.event_threshold = thresh_list
if not isinstance(self.event_confidence, list ):
conf_list = []
for (event_name,event_label) in self.key_categories.items():
conf_list.append(self.event_threshold)
self.event_confidence = conf_list
#Normalize in volume
self.originalimage = normalizeFloatZeroOne(self.originalimage, 1, 99.8, dtype = self.dtype)
if self.remove_markers == True:
Expand Down Expand Up @@ -372,7 +381,7 @@ def default_pass_predict(self):

event_prob = box[event_name]
event_confidence = box['confidence']
if event_prob >= self.event_threshold and event_confidence >= self.event_confidence:
if event_prob >= self.event_threshold[event_label] and event_confidence >= self.event_confidence[event_label]:

current_event_box.append(box)
classedboxes[event_name] = [current_event_box]
Expand Down Expand Up @@ -540,7 +549,7 @@ def second_pass_predict(self):
for box in eventboxes:
event_prob = box[event_name]
event_confidence = box['confidence']
if event_prob >= self.event_threshold and event_confidence >= self.event_confidence :
if event_prob >= self.event_threshold[event_label] and event_confidence >= self.event_confidence[event_label] :
current_event_box.append(box)

classedboxes[event_name] = [current_event_box]
Expand All @@ -566,7 +575,7 @@ def fast_nms(self):
for (event_name,event_label) in self.key_categories.items():
if event_label == 0:
#best_sorted_event_box = self.classedboxes[event_name][0]
best_sorted_event_box = volume_dynamic_nms(self.classedboxes, event_name, self.iou_threshold, self.event_threshold, self.imagex, self.imagey, self.imagez, nms_function = self.nms_function )
best_sorted_event_box = volume_dynamic_nms(self.classedboxes, event_name, self.iou_threshold, self.event_threshold[event_label], self.imagex, self.imagey, self.imagez, nms_function = self.nms_function )

best_iou_classedboxes[event_name] = [best_sorted_event_box]

Expand All @@ -582,7 +591,7 @@ def nms(self):
for (event_name,event_label) in self.key_categories.items():
if event_label > 0:
#best_sorted_event_box = self.classedboxes[event_name][0]
best_sorted_event_box = volume_dynamic_nms(self.classedboxes, event_name, self.iou_threshold, self.event_threshold, self.imagex, self.imagey, self.imagez, nms_function = self.nms_function )
best_sorted_event_box = volume_dynamic_nms(self.classedboxes, event_name, self.iou_threshold, self.event_threshold[event_label], self.imagex, self.imagey, self.imagez, nms_function = self.nms_function )

best_iou_classedboxes[event_name] = [best_sorted_event_box]

Expand Down

0 comments on commit 5736469

Please sign in to comment.