From 0c3241c5aed5ce796b16cbeaa2d32cdef3dc5307 Mon Sep 17 00:00:00 2001
From: Daniel Wessel
Date: Thu, 11 Jan 2024 15:51:02 +0100
Subject: [PATCH] Move model initialization into the class.

Allows passing a custom Hugging Face model identifier or local path.
Fixes device placement: inputs are now moved to the same device as the
model.
---
 autodistill_owl_vit/owlvit.py | 37 +++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/autodistill_owl_vit/owlvit.py b/autodistill_owl_vit/owlvit.py
index 2727f9e..8ce2655 100644
--- a/autodistill_owl_vit/owlvit.py
+++ b/autodistill_owl_vit/owlvit.py
@@ -1,44 +1,47 @@
 import os
 from dataclasses import dataclass
+from typing import Optional, Union
 
 import numpy as np
 import supervision as sv
 import torch
-from transformers import OwlViTForObjectDetection, OwlViTProcessor
-
 from autodistill.detection import CaptionOntology, DetectionBaseModel
 from autodistill.helpers import load_image
+from transformers import OwlViTForObjectDetection, OwlViTProcessor
 
 HOME = os.path.expanduser("~")
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(
-    DEVICE
-)
-processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
-
 
 @dataclass
 class OWLViT(DetectionBaseModel):
     ontology: CaptionOntology
-    owlvit_model: model
 
-    def __init__(self, ontology: CaptionOntology):
+    def __init__(
+        self,
+        ontology: CaptionOntology,
+        model: Optional[Union[str, os.PathLike]] = "google/owlvit-base-patch32",
+    ):
         self.ontology = ontology
-        self.owlvit_model = model
+        self.processor = OwlViTProcessor.from_pretrained(model)
+        self.model = OwlViTForObjectDetection.from_pretrained(model).to(DEVICE)
 
-    def predict(self, input: str, confidence = 0.1) -> sv.Detections:
+    def predict(self, input: str, confidence=0.1) -> sv.Detections:
         labels = self.ontology.prompts()
 
         image = load_image(input, return_format="PIL")
 
         with torch.no_grad():
-            inputs = processor(text=labels, images=image, return_tensors="pt")
-            outputs = model(**inputs)
+            inputs = self.processor(text=labels, images=image, return_tensors="pt").to(
+                DEVICE
+            )
+            outputs = self.model(**inputs)
 
             target_sizes = torch.Tensor([image.size[::-1]])
 
-            results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
+            results = self.processor.post_process_object_detection(
+                outputs=outputs, target_sizes=target_sizes
+            )
 
             i = 0
 
@@ -46,11 +49,11 @@ def predict(self, input: str, confidence = 0.1) -> sv.Detections:
             scores = results[i]["scores"].tolist()
             labels = results[i]["labels"].tolist()
 
-            print(scores)
-
             # filter with score < confidence
             boxes = [box for box, score in zip(boxes, scores) if score > confidence]
-            labels = [label for label, score in zip(labels, scores) if score > confidence]
+            labels = [
+                label for label, score in zip(labels, scores) if score > confidence
+            ]
             scores = [score for score in scores if score > confidence]
 
             if len(boxes) == 0:
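
Usage sketch for the new `model` parameter, based on the constructor
signature in this patch. The ontology mapping, image path, and the
"google/owlvit-base-patch16" checkpoint below are illustrative and not
part of the change:

    from autodistill.detection import CaptionOntology
    from autodistill_owl_vit.owlvit import OWLViT

    # Map free-text OWL-ViT prompts to the class names you want in the
    # distilled dataset (example mapping).
    ontology = CaptionOntology({"a photo of a person": "person"})

    # The default stays "google/owlvit-base-patch32"; any OWL-ViT
    # checkpoint id or local path can be passed instead.
    base_model = OWLViT(ontology=ontology, model="google/owlvit-base-patch16")

    # Returns sv.Detections, filtered to scores above `confidence`.
    detections = base_model.predict("image.jpeg", confidence=0.3)

A side effect of holding the processor and model as instance attributes
is that two OWLViT instances with different checkpoints no longer share
module-level state.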