Skip to content

Latest commit

 

History

History
59 lines (43 loc) · 1.71 KB

NEW_METHOD.md

File metadata and controls

59 lines (43 loc) · 1.71 KB

Add new Method

A new method for OOD detection must be implemented as a subclass of the `OODMethod` class. As shown below, the only method that must be implemented is `calculate_ood_score`, which returns a batch of scores (classification) or masks of scores (segmentation) for a sample batch.

class OODMethod:
    """Base class for out-of-distribution (OOD) detection methods.

    Subclasses must override :meth:`calculate_ood_score`; every other method
    has a default implementation that may be overridden when needed.
    Calling an instance scores a batch and stores the scores in the batch
    metadata under the method's display name.
    """

    def __init__(self, hyperparams=None) -> None:
        """Initialise the method.

        Args:
            hyperparams: Optional hyperparameters. The base class ignores
                them. The default is ``None`` rather than ``{}`` so that no
                mutable object is shared between calls.
        """

    # NOTE: the parameter name ``dict`` shadows the builtin; it is kept
    # unchanged because callers may pass it by keyword.
    def get_params(self, dict=False):
        """Return the method's parameters.

        Args:
            dict: If true, return a dictionary; otherwise return a string.

        Returns:
            ``{}`` when ``dict`` is true, else the string ``"no params"``.
        """
        return {} if dict else "no params"

    def display_name(self) -> str:
        """Return the name of the concrete class."""
        return type(self).__name__

    def modify_net(self, net):
        """Hook called before scoring; the base version returns ``net`` as is."""
        return net

    def remodify_net(self, net):
        """Hook called after scoring; the base version returns ``net`` as is."""
        return net

    def calculate_ood_score(self, imgs, net, batch=None):
        """Compute the OOD scores; must be overridden by subclasses.

        ``__call__`` invokes ``.numpy()`` on the returned value, so the
        result must provide that method (presumably a tensor - confirm
        against the concrete implementations).

        Args:
            imgs: The images to score (``batch.image`` when invoked through
                ``__call__``).
            net: The network returned by :meth:`modify_net`.
            batch: The full batch, passed along by ``__call__``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Please Implement this method")

    def _set_metadata(self, batch: Batch, scores) -> Batch:
        """Store ``scores`` in the batch metadata and return the batch.

        The scores are appended under the ``OOD_SCORE`` key, keyed by
        :meth:`display_name`.
        """
        batch["metadata"].append_to_keyed_dict(
            SampleMetadataCommonTypes.OOD_SCORE.name,
            self.display_name(),
            scores,
        )
        return batch

    def __call__(
        self,
        batch: Batch,
        net,
        task_type: TaskType = TaskType.SEGMENTATION,
        remove_ignore_index: bool = True,
        score_reduction_method: Callable = np.mean,
    ) -> Batch:
        """Score ``batch`` with ``net`` and record the scores in its metadata.

        ``task_type``, ``remove_ignore_index`` and ``score_reduction_method``
        are not used in this excerpt because the postprocessing step has
        been removed from it.

        Args:
            batch: The batch to score; its ``image`` attribute is passed to
                :meth:`calculate_ood_score`.
            net: The network used for scoring.
            task_type: Unused in this excerpt.
            remove_ignore_index: Unused in this excerpt.
            score_reduction_method: Unused in this excerpt.

        Returns:
            The batch returned by :meth:`_set_metadata`.
        """
        # modify network if needed
        net = self.modify_net(net)

        try:
            scores = self.calculate_ood_score(batch.image, net, batch).numpy()

            ##########################
            # POSTPROCESSING removed #
            ##########################

            # save scores in metadata to later be accessed by the metrics
            batch = self._set_metadata(batch, scores)
        finally:
            # always remodify network so it works properly for other methods,
            # even when scoring raised an exception
            self.remodify_net(net)

        return batch