instance metrics
Tonks684 committed Aug 15, 2024
1 parent 7f20542 commit b38d4e0
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions solution.py
@@ -71,6 +71,7 @@
import matplotlib.pyplot as plt
from cellpose import models
from typing import List, Tuple
from numpy.typing import ArrayLike
import warnings
warnings.filterwarnings('ignore')

@@ -405,14 +406,14 @@
"""
# %%
# Define a function to crop the images so we can zoom in.
def crop(img, crop_size, type=None):
def crop(img, crop_size, loc='center'):
"""
Crop the input image.
Parameters:
img (ndarray): The image to be cropped.
crop_size (int): The size of the crop.
type (str): The type of crop to perform. Can be 'center' or 'random'.
loc (str): The crop location. Either 'center' or 'random'.
Returns:
ndarray: The cropped image array.
@@ -424,18 +425,18 @@ def crop(img, crop_size, type=None):
max_y = height - crop_size
max_x = max_y

if type == 'random':
if loc == 'random':
start_y = np.random.randint(0, max_y + 1)
start_x = np.random.randint(0, max_x + 1)
end_y = start_y + crop_size
end_x = start_x + crop_size
elif type == 'center':
elif loc == 'center':
start_x = (width - crop_size) // 2
start_y = (height - crop_size) // 2
end_y = height - start_y
end_x = width - start_x
else:
raise ValueError(f'Unknown crop type {type}')
raise ValueError(f'Unknown crop type {loc}')

# Crop array using slicing
crop_array = img[start_x:end_x, start_y:end_y]
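For reference, a minimal usage sketch of the renamed loc argument (the random array and the 256-pixel crop size below are illustrative, not taken from the exercise):

# Hypothetical usage of crop() with the new loc keyword (illustrative only).
import numpy as np

img = np.random.rand(512, 512)                 # stand-in for a 2-D phase image
center_patch = crop(img, 256, loc='center')    # deterministic central 256x256 patch
random_patch = crop(img, 256, loc='random')    # randomly positioned 256x256 patch
print(center_patch.shape, random_patch.shape)  # both (256, 256)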
@@ -457,7 +458,7 @@ def visualise_results():
######## Solution ########
##########################

def visualise_results(phase_images, target_stains, virtual_stains, crop_size=None, type='center'):
def visualise_results(phase_images, target_stains, virtual_stains, crop_size=None, loc='center'):
"""
Visualizes the results of image processing by displaying the phase images, target stains, and virtual stains.
Parameters:
@@ -473,23 +474,28 @@ def visualise_results(phase_images, target_stains, virtual_stains, crop_size=Non
sample_indices = np.random.choice(len(phase_images), 5)
for index,sample in enumerate(sample_indices):
if crop_size:
phase_images[index] = crop(phase_images[index], crop_size, type)
target_stains[index] = crop(target_stains[index], crop_size, type)
virtual_stains[index] = crop(virtual_stains[index], crop_size, type)
axes[index, 0].imshow(phase_images[index], cmap="gray")
phase_image = crop(phase_images[index], crop_size, loc)
target_stain = crop(target_stains[index], crop_size, loc)
virtual_stain = crop(virtual_stains[index], crop_size, loc)
else:
phase_image = phase_images[index]
target_stain = target_stains[index]
virtual_stain = virtual_stains[index]

axes[index, 0].imshow(phase_image, cmap="gray")
axes[index, 0].set_title("Phase")
axes[index, 1].imshow(
target_stains[index],
target_stain,
cmap="gray",
vmin=np.percentile(target_stains[index], 1),
vmax=np.percentile(target_stains[index], 99),
vmin=np.percentile(target_stain, 1),
vmax=np.percentile(target_stain, 99),
)
axes[index, 1].set_title("Target Fluorescence ")
axes[index, 2].imshow(
virtual_stains[index],
virtual_stain,
cmap="gray",
vmin=np.percentile(target_stains[index], 1),
vmax=np.percentile(target_stains[index], 99),
vmin=np.percentile(target_stain, 1),
vmax=np.percentile(target_stain, 99),
)
axes[index, 2].set_title("Virtual Stain")
for ax in axes.flatten():
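A hedged usage sketch of the updated visualise_results signature; the arrays below are placeholders for the test-set images loaded earlier in the notebook:

# Illustrative calls only: phase_images, target_stains, virtual_stains are assumed
# to be (N, H, W) arrays already in memory.
visualise_results(phase_images, target_stains, virtual_stains)                                # full field of view
visualise_results(phase_images, target_stains, virtual_stains, crop_size=256, loc='center')   # zoom on the center
visualise_results(phase_images, target_stains, virtual_stains, crop_size=256, loc='random')   # zoom on a random patch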
@@ -519,12 +525,12 @@ def min_max_scale(input):

# Create a dataframe to store the pixel-level metrics.
test_pixel_metrics = pd.DataFrame(
columns=["model", "fov","pearson_nuc", "SSIM_nuc", "psnr_nuc"]
columns=["model", "fov","pearson_nuc", "ssim_nuc", "psnr_nuc"]
)

# Compute the pixel-level metrics.
for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains))):
fov = virtual_stain_paths[i].splt("/")[-1].split(".")[0]
fov = str(virtual_stain_paths[i]).split("/")[-1].split(".")[0]
minmax_norm_target = min_max_scale(target_stain)
minmax_norm_predicted = min_max_scale(predicted_stain)
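The metric computations themselves are collapsed in this view; a minimal sketch of how pearson_nuc, ssim_nuc, and psnr_nuc are commonly derived from the min-max-scaled pair, assuming scipy and scikit-image (the exercise's own code may differ):

# Assumed imports; the collapsed hunk may compute these metrics differently.
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# Pixel-wise Pearson correlation between target and virtual stain.
pearson_nuc, _ = pearsonr(minmax_norm_target.flatten(), minmax_norm_predicted.flatten())
# SSIM and PSNR on the [0, 1]-scaled images.
ssim_nuc = structural_similarity(minmax_norm_target, minmax_norm_predicted, data_range=1.0)
psnr_nuc = peak_signal_noise_ratio(minmax_norm_target, minmax_norm_predicted, data_range=1.0)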

@@ -544,19 +550,21 @@ def min_max_scale(input):
test_pixel_metrics.loc[len(test_pixel_metrics)] = {
"model": "pix2pixHD",
"fov":fov,
"Pearson_nuc": pearson_nuc,
"SSIM_nuc": ssim_nuc,
"PSNR_nuc": psnr_nuc,
"pearson_nuc": pearson_nuc,
"ssim_nuc": ssim_nuc,
"psnr_nuc": psnr_nuc,
}

test_pixel_metrics.boxplot(
column=["Pearson_nuc", "SSIM_nuc"],
column=["pearson_nuc", "ssim_nuc"],
rot=30,
)
# %%
test_pixel_metrics.boxplot(
column=["psnr_nuc"],
rot=30,
)
# %%
test_pixel_metrics.head()
# %% [markdown]
"""
@@ -610,7 +618,7 @@ def cellpose_segmentation(prediction:ArrayLike,target:ArrayLike)->Tuple[torch.Sh
segmentation_results = ()

for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains))):
fov = virtual_stain_paths[i].splt("/")[-1].split(".")[0]
fov = str(virtual_stain_paths[i]).split("/")[-1].split(".")[0]
minmax_norm_target = min_max_scale(target_stain)
minmax_norm_predicted = min_max_scale(predicted_stain)
# Compute the segmentation masks.
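The body of cellpose_segmentation and the downstream instance metrics are collapsed above; one common pattern with the classic Cellpose API is sketched below (the pretrained 'nuclei' model, the diameter handling, and the use of cellpose.metrics.average_precision are assumptions, not taken from this commit):

# Assumption: classic Cellpose 2.x/3.x API; the actual solution may differ.
from cellpose import models, metrics as cp_metrics

model = models.Cellpose(model_type='nuclei')   # pretrained nuclei model
# Segment the scaled target and virtual stains; channels=[0, 0] means a single grayscale channel.
target_masks, _, _, _ = model.eval(minmax_norm_target, channels=[0, 0], diameter=None)
pred_masks, _, _, _ = model.eval(minmax_norm_predicted, channels=[0, 0], diameter=None)

# Instance-level agreement: average precision at IoU thresholds 0.5, 0.75, 0.9.
ap, tp, fp, fn = cp_metrics.average_precision([target_masks], [pred_masks])
print(ap)  # one row per image, one column per IoU threshold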
