Skip to content

Commit

Permalink
major changes. Checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Sep 26, 2024
1 parent edcfb19 commit 5546ca7
Show file tree
Hide file tree
Showing 13 changed files with 1,375 additions and 1,054 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ __pycache__/
*.pyc
*.egg-info/
garuda/_version.py
lab/*
!lab/*.ipynb
build
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ Latest version:
pip install git+https://github.com/patel-zeel/garuda
```

## Non-circular imports
* core -> box
* box -> utils
* core + box -> annotate

## Terminology

| Term | Description |
Expand Down Expand Up @@ -157,4 +162,8 @@ img = plt.imread('data/images/22.32,87.93.png')

fig, ax = plt.subplots()
ax = plot_webm_pixel_to_geo(img, img_center_lat, img_center_lon, zoom, ax)
```
```

### Why `OBBLabel` and not `OBBLabels`
* Non-vectorized operations are easier to understand and debug.
* It'd be easy for us to separate out exact false positives and false negatives at a single object level.
247 changes: 247 additions & 0 deletions garuda/annotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import os
import cv2
from glob import glob
import numpy as np
from ipyleaflet import GeomanDrawControl
import leafmap
from copy import deepcopy
import geojson
from ipywidgets import Button, Label, HBox, Dropdown
from IPython.display import display
from garuda.core import geo_to_webm_pixel, webm_pixel_to_geo, xywhr2xyxyxyxy
from garuda.box import OBBLabel
from shapely.geometry import Polygon

class AnnotationTool:
def __init__(self, labels, classes, zoom, cache_dir, clear_cache=False):
self.original_labels = deepcopy(labels)
self.labels = deepcopy(labels)
self.classes = classes
self.zoom = zoom
self.cache_dir = cache_dir
os.makedirs(self.cache_dir, exist_ok=True)
if clear_cache:
label_files = glob(os.path.join(self.cache_dir, "label_*.geojson"))
for label_file in label_files:
os.remove(label_file)

# initialize
self.index = 0

# Map
self.m = leafmap.Map(center=(27, 77), zoom=self.zoom)
self.m.add_basemap("Esri.WorldImagery")
self.m.remove_control(self.m.draw_control)
self.draw_control = GeomanDrawControl(position='topright')

def on_draw(*args, **kwargs):
self.status_label.value = "Submit the label to update it."
self.disable_buttons()

self.draw_control.on_draw(on_draw)

self.draw_control.rectangle = {
"pathOptions": {
"fillColor": "#fca45d",
"color": "#fca45d",
"fillOpacity": 0.0
}
}

self.m.add_control(self.draw_control)

# Interface elements
# current label
self.show_label = Label(f"Label {self.index+1}/{len(labels)}")

# status label
self.status_label = Label("")

# next_button
self.next_button = Button(description="next")
self.next_button.on_click(self.next_button_clicked)

# previous_button
self.previous_button = Button(description="previous")
self.previous_button.on_click(self.previous_button_clicked)

# submit button
self.submit_button = Button(description="submit")
self.submit_button.on_click(self.submit_button_clicked)

# reset button
self.reset_button = Button(description="reset_current_label")
self.reset_button.on_click(self.reset_button_clicked)

# classes dropdown
self.classes_dropdown = Dropdown(options=self.classes)
self.classes_dropdown.on_trait_change(self.on_dropdown_change, 'value')

display(self.show_label)
display(self.status_label)
display(HBox([self.submit_button, self.previous_button, self.next_button, self.reset_button, self.classes_dropdown]))
display(self.m)

# initialize
loaded_from_cache = self.show_current_label()
while loaded_from_cache:
self.next_button_clicked()
loaded_from_cache = self.show_current_label()

def show_current_label(self):
self.disable_buttons()
loaded_from_cache = False
if os.path.exists(f"{self.cache_dir}/label_{self.index}.geojson"):
with open(f"{self.cache_dir}/label_{self.index}.geojson", "r") as f:
data = f.read().strip()
if data == "Empty_label":
self.labels[self.index] = None
else:
feature = geojson.loads(data)['features'][0]
self.labels[self.index] = OBBLabel.from_geojson(feature)
self.enable_buttons() # allow to move around if label is already present
loaded_from_cache = True

label = self.labels[self.index]
if label is None:
self.draw_control.data = []
self.status_label.value = "No label available."
original_label = self.original_labels[self.index]
self.m.set_center(original_label.properties['center_lon'], original_label.properties['center_lat'], zoom=self.zoom)
self.enable_buttons() # allow to move around if label is empty
else:
if loaded_from_cache:
self.status_label.value = "Label loaded from cache. Submit only to make changes."
else:
self.status_label.value = "Label is a valid polygon. Make changes if needed and submit."
# set pov
self.m.set_center(label.properties['center_lon'], label.properties['center_lat'], zoom=self.zoom)

# show current label
feature = label.to_geojson(source=None, task_name=None)
# show the boundary in red color
feature['properties']['style'] = {'color': self.get_color(feature), 'fillColor': self.get_color(feature), 'fillOpacity': 0.0}
self.draw_control.data = [] # first clear the existing data to trigger the changes in GUI
self.draw_control.data = [feature]
self.classes_dropdown.value = feature['properties']['class_name']

# update label
self.show_label.value = f"Label {self.index+1}/{len(self.labels)}"
return loaded_from_cache

def disable_buttons(self):
self.next_button.disabled = True
# self.previous_button.disabled = True

def enable_buttons(self):
self.next_button.disabled = False
# self.previous_button.disabled = False

@staticmethod
def get_color(feature):
class_name = feature['properties']['class_name']
if class_name == "CFCBK":
# red
return "#ff0000"
elif class_name == "FCBK":
# orange
return "#ffa500"
elif class_name == "Zigzag":
# green
return "#00ff00"
else:
# blue
return "#0000ff"

def submit_button_clicked(self, *args, **kwargs):
if len(self.draw_control.data) == 0:
self.labels[self.index] = None
# remove label from cache
cache_path = f"{self.cache_dir}/label_{self.index}.geojson"
with open(cache_path, "w") as f:
f.write("Empty_label")
self.enable_buttons() # allow to move around if label is empty
return

feature = self.draw_control.data[-1]
try:
assert feature['geometry']['type'] == 'Polygon'
except AssertionError:
if feature['geometry']['type'] != 'Polygon':
self.status_label.value = "Invalid label. Please correct it or delete it."
return

coords = []
for lon, lat in feature['geometry']['coordinates'][0]:
x, y = geo_to_webm_pixel(lat, lon, self.zoom)
coords.append([x, y])
coords = np.array(coords, dtype=np.float32)

(x, y), (w, h), r = cv2.minAreaRect(coords)
r = np.deg2rad(r)
rect = xywhr2xyxyxyxy(np.array([x, y, w, h, r]))

coords = []
for pair in rect:
lat, lon = webm_pixel_to_geo(pair[0], pair[1], self.zoom)
coords.append([lon, lat])
poly = Polygon(coords)

feature['geometry']['coordinates'] = [list(poly.exterior.coords)]
feature['properties']['source'] = 'hand_validated'
feature['properties']['task_name'] = 'hand_validation'
feature['properties']['class_name'] = self.classes_dropdown.value
self.labels[self.index] = OBBLabel.from_geojson(feature)
self.cache_label()
self.show_current_label()
self.enable_buttons()

def cache_label(self):
cache_path = f"{self.cache_dir}/label_{self.index}.geojson"
feature = self.labels[self.index].to_geojson(source=None, task_name=None)
collection = geojson.FeatureCollection([feature])
with open(cache_path, "w") as f:
geojson.dump(collection, f)

def reset_button_clicked(self, *args, **kwargs):
original_label = self.original_labels[self.index]
self.labels[self.index] = original_label
self.cache_label()
self.show_current_label()

def on_dropdown_change(self, old, new):
if new != self.labels[self.index].properties['class_name']:
self.labels[self.index].properties['class_name'] = new
self.cache_label()
self.show_current_label()

def next_button_clicked(self, *args, **kwargs):
# show next label
if self.index >= (len(self.labels) - 1):
pass # do nothing
else:
self.index += 1
self.show_current_label()

def previous_button_clicked(self, *args, **kwargs):
# show next label
if self.index <= 0:
pass # do nothing
else:
self.index -= 1
self.show_current_label()

def to_geojson(self):
features = []
labels = glob(f"{self.cache_dir}/label_*.geojson")
for label in labels:
with open(label, "r") as f:
feature = geojson.load(f)['features'][0]
features.append(feature)
collection = geojson.FeatureCollection(features)
return collection

def save_to_geojson(self, save_path):
collection = self.to_geojson()
with open(save_path, "w") as f:
geojson.dump(collection, f)
Loading

0 comments on commit 5546ca7

Please sign in to comment.