Merge pull request #49 from lincc-frameworks/issue/35/cutout-interface-cleanup

Adding user-specified cutout crop size.
mtauraso authored Aug 29, 2024
2 parents 79293c8 + 9cd3680 commit f5b57da
Showing 4 changed files with 197 additions and 70 deletions.
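In brief: HSCDataSet gains a keyword-only cutout_shape argument, plumbed through from a new crop_to config option. A minimal usage sketch (the path and crop size below are hypothetical; the import path is inferred from the file's location under src/):

from fibad.data_loaders.hsc_data_loader import HSCDataSet

# Force every cutout to 100x100 px. Per _prune_objects in the diff below, any
# object with a file smaller than this is dropped from the dataset at load time.
dataset = HSCDataSet("./data", cutout_shape=(100, 100))
print(len(dataset))  # number of objects that survived pruning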
189 changes: 126 additions & 63 deletions src/fibad/data_loaders/hsc_data_loader.py
@@ -3,7 +3,7 @@
import logging
import re
from pathlib import Path
from typing import Union
from typing import Optional, Union

import numpy as np
import torch
@@ -45,7 +45,11 @@ def data_set(self):
# Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step.
transform = Lambda(lambd=np.tanh)

return HSCDataSet(self.config.get("path", "./data"), transform=transform)
return HSCDataSet(
self.config.get("path", "./data"),
transform=transform,
cutout_shape=self.config.get("crop_to", None),
)

def data_loader(self, data_set):
return torch.utils.data.DataLoader(
@@ -60,7 +64,9 @@ def shape(self):


class HSCDataSet(Dataset):
def __init__(self, path: Union[Path, str], transform=None):
def __init__(
self, path: Union[Path, str], *, transform=None, cutout_shape: Optional[tuple[int, int]] = None
):
"""Initialize an HSC data set from a path. This involves several filesystem scan operations and will
ultimately open and read the header info of every fits file in the given directory
@@ -69,35 +75,41 @@ def __init__(self, path: Union[Path, str], transform=None):
path : Union[Path, str]
Path or string specifying the directory path to scan. It is expected that all files will
be flat in this directory
transform : _type_, optional
_description_, by default None
transform : torchvision.transforms.v2.Transform, optional
Transformation to apply to every image in the dataset, by default None
cutout_shape: tuple[int,int], optional
Forces all cutouts to be a particular pixel size. RuntimeError is raised if this size is larger
than the pixel dimension of any cutout in the dataset.
"""
self.path = path
self.transform = transform

self.files = self._scan_files()
self.files = self._scan_file_names()
self.dims = self._scan_file_dimensions()

# We choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by
# _prune_objects
filters_ref = list(list(self.files.values())[0])

self.num_filters = len(filters_ref)

self.object_ids = self._prune_objects(self.files, filters_ref)
self.cutout_shape = cutout_shape

self.object_ids = self._prune_objects(filters_ref)

self.cutout_width, self.cutout_height = self._check_file_dimensions()
if self.cutout_shape is None:
self.cutout_shape = self._check_file_dimensions()

# Set up our default transform to center-crop the image to the common size before
# Applying any transforms we were passed.
crop = CenterCrop(size=(self.cutout_width, self.cutout_height))
crop = CenterCrop(size=self.cutout_shape)
self.transform = Compose([crop, self.transform]) if self.transform is not None else crop

self.tensors = {}

logger.info(f"HSC Data set loader has {len(self)} objects")

def _scan_files(self) -> dict[str, dict[str, str]]:
def _scan_file_names(self) -> dict[str, dict[str, str]]:
"""Class initialization helper
Returns
@@ -117,13 +129,29 @@ def _scan_files(self) -> dict[str, dict[str, str]]:
if files.get(object_id) is None:
files[object_id] = {}

files[object_id][filter] = filename
if files[object_id].get(filter) is None:
files[object_id][filter] = filename
else:
msg = f"Duplicate object ID {object_id} detected.\n"
msg += f"File {filename} conflicts with already scanned file {files[object_id][filter]} "
msg += "and will not be included in the data set."
logger.error(msg)

return files

def _prune_objects(self, files: dict[str, dict[str, str]], filters_ref: list[str]) -> list[str]:
"""Class initialization helper. Prunes files dict (which will be self.files). Removes any objects
which do not have all the filters specified in filters_ref
def _scan_file_dimensions(self) -> dict[str, list[tuple[int, int]]]:
# Scan the filesystem to get the widths and heights of all images into a dict
return {
object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)]
for object_id in self._all_object_ids()
}

def _prune_objects(self, filters_ref: list[str]) -> list[str]:
"""Class initialization helper. Prunes objects from the list of objects.
1) Removes any objects which do not have all the filters specified in filters_ref
2) If a cutout_shape was provided in the constructor, prunes files that are too small
for the chosen cutout size
Parameters
----------
@@ -140,37 +168,49 @@ def _prune_objects(self, files: dict[str, dict[str, str]], filters_ref: list[str
List of all object IDs which survived the prune.
"""
filters_ref = sorted(filters_ref)
prune_count = 0
for object_id, filters in list(files.items()):
self.prune_count = 0
for object_id, filters in list(self.files.items()):
# Drop objects with missing filters
filters = sorted(list(filters))
if filters != filters_ref:
logger.warning(
f"HSCDataSet in {self.path} has the wrong group of filters for object {object_id}."
)
logger.warning(f"Dropping object {object_id} from the dataset.")
msg = f"HSCDataSet in {self.path} has the wrong group of filters for object {object_id}."
self._prune_object(object_id, msg)
logger.info(f"Filters for object {object_id} were {filters}")
logger.debug(f"Reference filters were {filters_ref}")
prune_count += 1
# Remove any object IDs for which we don't have all the filters
del files[object_id]

# Dump all object IDs into a list so there is an explicit indexing/ordering convention
# valid for the lifetime of this object.
object_ids = list(files)
# Drop objects that can't meet the cutout size provided
elif self.cutout_shape is not None:
for shape in self.dims[object_id]:
if shape[0] < self.cutout_shape[0] or shape[1] < self.cutout_shape[1]:
msg = f"A file for object {object_id} has shape ({shape[1]}px, {shape[1]}px)"
msg += " this is too small for the given cutout size of "
msg += f"({self.cutout_shape[0]}px, {self.cutout_shape[1]}px)"
self._prune_object(object_id, msg)
break

# Log about the pruning process
pre_prune_object_count = len(object_ids) + prune_count
prune_fraction = prune_count / pre_prune_object_count
pre_prune_object_count = len(self.files) + self.prune_count
prune_fraction = self.prune_count / pre_prune_object_count
if prune_fraction > 0.05:
logger.error("Greater than 5% of objects in the data directory were pruned.")
elif prune_fraction > 0.01:
logger.warning("Greater than 1% of objects in the data directory were pruned.")
logger.info(f"Pruned {prune_count} out of {pre_prune_object_count} objects")
logger.info(f"Pruned {self.prune_count} out of {pre_prune_object_count} objects")

def _prune_object(self, object_id, reason: str):
logger.warning(reason)
logger.warning(f"Dropping object {object_id} from the dataset")

return object_ids
del self.files[object_id]
del self.dims[object_id]
self.prune_count += 1

def _fits_file_dims(self, filepath):
with fits.open(filepath) as hdul:
return hdul[1].shape

def _check_file_dimensions(self) -> tuple[int, int]:
"""Class initialization helper. Scan all files to determine the minimal pixel size of images
"""Class initialization helper. Find the maximal pixel size that all images can support
It is assumed that all the cutouts will be of very similar size; however, HSC's cutout
server does not return exactly the same number of pixels for every query, even when it
@@ -187,15 +227,11 @@ def _check_file_dimensions(self) -> tuple[int, int]:
The minimum width and height in pixels of the entire dataset. In other words: the maximal image
size in pixels that can be generated from ALL cutout images via cropping.
"""
all_widths, all_heights = ([], [])

for filepath in self._all_files():
with fits.open(filepath) as hdul:
width, height = hdul[1].shape
all_widths.append(width)
all_heights.append(height)

# Find the maximal cutout size that all images can support
all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list]
cutout_width = np.min(all_widths)

all_heights = [shape[1] for shape_list in self.dims.values() for shape in shape_list]
cutout_height = np.min(all_heights)

if (
@@ -204,12 +240,15 @@
or np.abs(np.max(all_widths) - np.mean(all_widths)) > 1
or np.abs(np.max(all_heights) - np.mean(all_heights)) > 1
):
logger.warning("Some images differ from the mean width or height of all images by more than 1px")
logger.warning(f"Images will be cropped to ({cutout_width}px, {cutout_height}px)")
min_width_file = self._get_file(np.argmin(all_widths))
logger.warning(f"See {min_width_file} for an example image of width {cutout_width}px")
min_height_file = self._get_file(np.argmin(all_heights))
logger.warning(f"See {min_height_file} for an example image of height {cutout_height}px")
msg = "Some images differ from the mean width or height of all images by more than 1px\n"
msg += f"Images will be cropped to ({cutout_width}px, {cutout_height}px)\n"
try:
min_width_file = self._get_file(np.argmin(all_widths))
min_height_file = self._get_file(np.argmin(all_heights))
msg += f"See {min_width_file} for an example image of width {cutout_width}px\n"
msg += f"See {min_height_file} for an example image of height {cutout_height}px"
finally:
logger.warning(msg)

return cutout_width, cutout_height

@@ -224,7 +263,7 @@ def shape(self) -> tuple[int, int, int]:
The second index is the width of each image
The third index is the height of each image
"""
return (self.num_filters, self.cutout_width, self.cutout_height)
return (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])

def __len__(self) -> int:
"""Returns number of objects in this loader
@@ -234,18 +273,31 @@ def __len__(self) -> int:
int
number of objects in this data loader
"""
return len(self.object_ids)
return len(self.files)

def __getitem__(self, idx: int) -> torch.Tensor:
if idx >= len(self.object_ids) or idx < 0:
if idx >= len(self.files) or idx < 0:
raise IndexError

# Use the list of object IDs for explicit indexing
object_id = self.object_ids[idx]
object_id = list(self.files.keys())[idx]

tensor = self._object_id_to_tensor(object_id)
return tensor
return self._object_id_to_tensor(object_id)

def __contains__(self, object_id: str) -> bool:
"""Allows you to do `object_id in dataset` queries

Parameters
----------
object_id : str
The object ID you'd like to know if is in the dataset

Returns
-------
bool
True if the object_id given is in the data set
"""
return object_id in list(self.files.keys()) and object_id in list(self.dims.keys())

def _get_file(self, index: int) -> Path:
"""Private indexing method across all files.
@@ -269,32 +321,43 @@ def _get_file(self, index: int) -> Path:
Path
The path to the file
"""
object_id = self.object_ids[int(index / self.num_filters)]
object_index = int(index / self.num_filters)
object_id = list(self.files.keys())[object_index]
filters = self.files[object_id]
filter_names = sorted(list(filters))
filter = filter_names[index % self.num_filters]
return self._file_to_path(filters[filter])

def _all_files(self) -> Path:
def _all_object_ids(self):
"""Private read-only iterator over all object_ids that enforces a strict total order across
objects. Will not work prior to self.files initialization in __init__
Yields
------
Iterator[str]
Object IDs currently in the dataset
"""
for object_id in self.files:
yield object_id

def _all_files(self):
"""
Private read-only iterator over all files that enforces a strict total order across
objects and filters. Will not work prior to self.object_ids, self.files, and self.path
initialization in __init__
objects and filters. Will not work prior to self.files and self.path initialization in __init__
Yields
------
Path
The path to the file.
"""
for object_id in self.object_ids:
for object_id in self._all_object_ids():
for filename in self._object_files(object_id):
yield filename

def _object_files(self, object_id) -> Path:
def _object_files(self, object_id):
"""
Private read-only iterator over all files for a given object. This enforces a strict total order
across filters. Will not work prior to self.object_ids, self.files, and self.path initialization
in __init__
across filters. Will not work prior to self.files and self.path initialization in __init__
Yields
------
@@ -331,8 +394,8 @@ def _file_to_path(self, filename: str) -> Path:
#
# For now we just do it the naive way
def _object_id_to_tensor(self, object_id: str) -> torch.Tensor:
"""Converts an object_id to a pytorch tensor with dimenstions (self.num_filters, self.cutout_width,
self.cutout_height). This is done by reading the file and slicing away any excess pixels at the
"""Converts an object_id to a pytorch tensor with dimenstions (self.num_filters, self.cutout_shape[0],
self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the
far corners of the image from (0,0).
The current implementation reads the files once the first time they are accessed, and then
@@ -346,7 +409,7 @@ def _object_id_to_tensor(self, object_id: str) -> torch.Tensor:
Returns
-------
torch.Tensor
A tensor with dimension (self.num_filters, self.cutout_width, self.cutout_height)
A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])
"""
data_torch = self.tensors.get(object_id, None)
if data_torch is not None:
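As a side note, the reworked _check_file_dimensions derives the common crop size as the per-axis minimum over every image in the dataset. A minimal sketch of that rule, with made-up dimensions:

import numpy as np

# Hypothetical self.dims mapping: object_id -> list of (width, height), one per filter.
dims = {
    "obj1": [(101, 100), (100, 102)],
    "obj2": [(100, 99)],
}

all_widths = [shape[0] for shape_list in dims.values() for shape in shape_list]
all_heights = [shape[1] for shape_list in dims.values() for shape in shape_list]

# The largest cutout every image can satisfy via cropping.
print(np.min(all_widths), np.min(all_heights))  # prints: 100 99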
7 changes: 7 additions & 0 deletions src/fibad/fibad_default_config.toml
@@ -59,6 +59,13 @@ name = "HSCDataLoader"
# Directory path where the data is stored
path = "./data"

# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
#
# If not provided, the default is to scan the directory for the smallest dimensioned files, and use
# those pixel dimensions as the crop size.
#
#crop_to = [100,100]

# Default PyTorch DataLoader parameters
batch_size = 500
shuffle = true
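For reference, a config that opts in to the new crop might look like this (values are illustrative; the enclosing TOML table header is outside the hunk shown above):

name = "HSCDataLoader"

# Directory path where the data is stored
path = "./data"

# Crop every cutout to 100x100 px; objects with smaller images are pruned.
crop_to = [100, 100]

batch_size = 500
shuffle = true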
2 changes: 1 addition & 1 deletion src/fibad/models/example_autoencoder.py
@@ -108,7 +108,7 @@ def train(self, trainloader, device=None):

torch.set_grad_enabled(True)

print(f"len(trainloder) = {len(trainloader)}")
# print(f"len(trainloder) = {len(trainloader)}")
for epoch in range(self.config.get("epochs", 2)):
running_loss = 0.0
for batch_num, data in enumerate(trainloader, 0):