Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Custom Models #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions autodistill_owl_vit/owlvit.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,59 @@
import os
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import supervision as sv
import torch
from transformers import OwlViTForObjectDetection, OwlViTProcessor

from autodistill.detection import CaptionOntology, DetectionBaseModel
from autodistill.helpers import load_image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

# User home directory, used by autodistill for cache/output locations.
HOME = os.path.expanduser("~")

# Run inference on GPU when available, otherwise fall back to CPU.
# NOTE: model weights are intentionally NOT loaded here at import time.
# Loading a hard-coded "google/owlvit-base-patch32" checkpoint at module
# import would download/allocate the base model even when the caller asks
# for a custom checkpoint, and the global names would shadow the per-instance
# `model`/`processor` created in OWLViT.__init__.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class OWLViT(DetectionBaseModel):
    """Autodistill base-model wrapper around OWL-ViT zero-shot detection.

    Attributes:
        ontology: Maps caption prompts to output class labels.
    """

    # Caption ontology supplying the text prompts used for detection.
    ontology: CaptionOntology

    def __init__(
        self,
        ontology: CaptionOntology,
        model: Optional[Union[str, os.PathLike]] = "google/owlvit-base-patch32",
    ):
        """Load an OWL-ViT checkpoint for zero-shot object detection.

        Args:
            ontology: Caption ontology mapping prompts to class labels.
            model: Hugging Face model id or local path of the OWL-ViT
                checkpoint to load. Defaults to the base patch-32 model,
                so existing single-argument callers are unaffected.
        """
        self.ontology = ontology
        # Processor handles text/image pre-processing and box post-processing.
        self.processor = OwlViTProcessor.from_pretrained(model)
        # Move the detection model to GPU when available.
        self.model = OwlViTForObjectDetection.from_pretrained(model).to(DEVICE)

def predict(self, input: str, confidence = 0.1) -> sv.Detections:
def predict(self, input: str, confidence=0.1) -> sv.Detections:
labels = self.ontology.prompts()

image = load_image(input, return_format="PIL")

with torch.no_grad():
inputs = processor(text=labels, images=image, return_tensors="pt")
outputs = model(**inputs)
inputs = self.processor(text=labels, images=image, return_tensors="pt").to(
DEVICE
)
outputs = self.model(**inputs)

target_sizes = torch.Tensor([image.size[::-1]])

results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
results = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes
)

i = 0

boxes = results[i]["boxes"].tolist()
scores = results[i]["scores"].tolist()
labels = results[i]["labels"].tolist()

print(scores)

# filter with score < confidence
boxes = [box for box, score in zip(boxes, scores) if score > confidence]
labels = [label for label, score in zip(labels, scores) if score > confidence]
labels = [
label for label, score in zip(labels, scores) if score > confidence
]
scores = [score for score in scores if score > confidence]

if len(boxes) == 0:
Expand Down
Loading