Skip to content

Commit

Permalink
refactor: moving code into their own modules
Browse files Browse the repository at this point in the history
  • Loading branch information
dilawar committed Jul 26, 2024
1 parent b59d3be commit 843632b
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 84 deletions.
1 change: 1 addition & 0 deletions plotdigitizer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

img_ = np.zeros((1, 1))


def cache() -> Path:
c = Path(tempfile.gettempdir()) / "plotdigitizer"
c.mkdir(parents=True, exist_ok=True)
Expand Down
61 changes: 59 additions & 2 deletions plotdigitizer/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,26 @@
from plotdigitizer import common
from plotdigitizer import geometry
from plotdigitizer import grid
from plotdigitizer import plot
from plotdigitizer.trajectory import find_trajectory


def click_points(event, x, y, _flags, params):
"""callback for opencv image"""
assert common.img_ is not None, "No data set"
# Function to record the clicks.
YROWS = common.img_.shape[0]
if event == cv.EVENT_LBUTTONDOWN:
logger.info(f"You clicked on {(x, YROWS-y)}")
common.locations_.append(geometry.Point(x, YROWS - y))


class Figure:
def __init__(self, path: Path):
def __init__(self, path: Path, coordinates: T.List[str], indices: T.List[str]):
assert path.exists(), f"{path} does not exists."
logger.info(f"Reading {path}")
self.indices = list_to_points(indices)
self.coordinates = list_to_points(coordinates)
self.path = path
self.orignal = cv.imread(self.path)
self.imgs = [("orig-gray-normalized", normalize(cv.imread(self.path, 0)))]
Expand All @@ -37,11 +51,30 @@ def trajectories(self):
def extract_trajectories(self):
logger.info(f"Extracting trajectories from {infile}")

def map_axis(self):
logger.info("Mapping axis...")
logger.debug(
f"data points {self.coordinates} → location on image {self.indices}"
)

if len(self.coordinates) != len(self.indices):
logger.warning(
"Either the location of data-points on the image is not specified or their numbers don't"
" match with given datapoints. Asking user to fill the missing information..."
)

# next function uses callback. Needs a global variable to collect
# data.
common.locations_ = self.indices
self.indices = ask_user_to_locate_points(self.coordinates, self._last())
assert len(self.coordinates) == len(self.indices)

def _last(self):
return self.imgs[-1][1]

def _append(self, operation: str, img):
self.imgs.append((operation, img))
self.imgs.append((operation, img))


def process_image(img, cache_key: T.Optional[str] = None):
global params_
Expand Down Expand Up @@ -157,3 +190,27 @@ def save_img_in_cache(
def normalize(img):
"""normalize image to 0, 255"""
return np.interp(img, (img.min(), img.max()), (0, 255)).astype(np.uint8)


def list_to_points(points) -> T.List[geometry.Point]:
ps = [geometry.Point.fromCSV(x) for x in points]
return ps


def ask_user_to_locate_points(points, img) -> list:
"""Ask user to map axis. Callback function save selected points in
common.locations_"""
cv.namedWindow(common.WindowName_)
cv.setMouseCallback(common.WindowName_, click_points)
while len(common.locations_) < len(points):
i = len(common.locations_)
p = points[i]
pLeft = len(points) - len(common.locations_)
plot.show_frame(img, "Please click on %s (%d left)" % (p, pLeft))
if len(common.locations_) == len(points):
break
key = cv.waitKey(1) & 0xFF
if key == "q":
break
logger.info("You clicked %s" % common.locations_)
return common.locations_
50 changes: 50 additions & 0 deletions plotdigitizer/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# helper function for plotting.

import typing as T
from pathlib import Path
import matplotlib.pyplot as plt
import cv2 as cv
import numpy as np
from loguru import logger

from plotdigitizer import common


def show_frame(img, msg="MSG: "):
msgImg = np.zeros(shape=(50, img.shape[1]))
cv.putText(msgImg, msg, (1, 40), 0, 0.5, 255)
newImg = np.vstack((img, msgImg.astype(np.uint8)))
cv.imshow(common.WindowName_, newImg)


def plot_traj(traj, img, outfile: T.Optional[Path] = None):
global locations_
import matplotlib.pyplot as plt

x, y = zip(*traj)
plt.figure()
plt.subplot(211)

for p in common.locations_:
csize = img.shape[0] // 40
cv.circle(
img,
(int(p.x), int(img.shape[0] - p.y)),
int(csize),
(128, 128, 128),
-1,
)

plt.imshow(img, interpolation="none", cmap="gray")
plt.axis(False)
plt.title("Original")
plt.subplot(212)
plt.title("Reconstructed")
plt.plot(x, y)
plt.tight_layout()
if not str(outfile):
plt.show()
else:
plt.savefig(outfile)
logger.info(f"Saved to {outfile}")
plt.close()
88 changes: 6 additions & 82 deletions plotdigitizer/plotdigitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
import typing as T
from pathlib import Path

import cv2 as cv

import numpy as np
import typer
from typing_extensions import Annotated

from plotdigitizer import grid
from plotdigitizer import image
from plotdigitizer import plot
from plotdigitizer import geometry
from plotdigitizer import common

Expand All @@ -22,72 +20,6 @@
app = typer.Typer()


def plot_traj(traj, outfile: Path):
global locations_
import matplotlib.pyplot as plt

x, y = zip(*traj)
plt.figure()
plt.subplot(211)

for p in common.locations_:
csize = common.img_.shape[0] // 40
cv.circle(
common.img_, (int(p.x), int(common.img_.shape[0] - p.y)), int(csize), (128, 128, 128), -1
)

plt.imshow(common.img_, interpolation="none", cmap="gray")
plt.axis(False)
plt.title("Original")
plt.subplot(212)
plt.title("Reconstructed")
plt.plot(x, y)
plt.tight_layout()
if not str(outfile):
plt.show()
else:
plt.savefig(outfile)
logger.info(f"Saved to {outfile}")
plt.close()


def click_points(event, x, y, _flags, params):
assert common.img_ is not None, "No data set"
# Function to record the clicks.
YROWS = common.img_.shape[0]
if event == cv.EVENT_LBUTTONDOWN:
logger.info(f"You clicked on {(x, YROWS-y)}")
common.locations_.append(geometry.Point(x, YROWS - y))


def show_frame(img, msg="MSG: "):
msgImg = np.zeros(shape=(50, img.shape[1]))
cv.putText(msgImg, msg, (1, 40), 0, 0.5, 255)
newImg = np.vstack((img, msgImg.astype(np.uint8)))
cv.imshow(common.WindowName_, newImg)


def ask_user_to_locate_points(points, img):
cv.namedWindow(common.WindowName_)
cv.setMouseCallback(common.WindowName_, click_points)
while len(common.locations_) < len(points):
i = len(common.locations_)
p = points[i]
pLeft = len(points) - len(common.locations_)
show_frame(img, "Please click on %s (%d left)" % (p, pLeft))
if len(common.locations_) == len(points):
break
key = cv.waitKey(1) & 0xFF
if key == "q":
break
logger.info("You clicked %s" % common.locations_)


def list_to_points(points) -> T.List[geometry.Point]:
ps = [geometry.Point.fromCSV(x) for x in points]
return ps


@app.command()
def digitize_plot(
infile: Path,
Expand Down Expand Up @@ -122,27 +54,19 @@ def digitize_plot(
),
] = None,
):
figure = image.Figure(infile)
figure = image.Figure(infile, data_point, location)

# remove grids.
figure.remove_grid()
image.save_img_in_cache(common.img_, infile.name)

common.points_ = list_to_points(data_point)
common.locations_ = list_to_points(location)
logger.debug(f"data points {data_point} → location on image {location}")

if len(common.locations_) != len(common.points_):
logger.warning(
"Either the location of data-points are not specified or their numbers don't"
" match with given datapoints. Asking user..."
)
ask_user_to_locate_points(common.points_, common.img_)
# map the axis
figure.map_axis()

# compute trajectories
traj = figure.trajectories()

if plot_file is not None:
plot_traj(traj, plot_file)
plot.plot_traj(traj, figure._last(), plot_file)

outfile = output or f"{infile}.traj.csv"
with open(outfile, "w") as f:
Expand Down

0 comments on commit 843632b

Please sign in to comment.