diff --git a/src/fibad/data_loaders/hsc_data_loader.py b/src/fibad/data_loaders/hsc_data_loader.py index 6f6849a..1546403 100644 --- a/src/fibad/data_loaders/hsc_data_loader.py +++ b/src/fibad/data_loaders/hsc_data_loader.py @@ -3,7 +3,7 @@ import logging import re from pathlib import Path -from typing import Union +from typing import Optional, Union import numpy as np import torch @@ -45,7 +45,11 @@ def data_set(self): # Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step. transform = Lambda(lambd=np.tanh) - return HSCDataSet(self.config.get("path", "./data"), transform=transform) + return HSCDataSet( + self.config.get("path", "./data"), + transform=transform, + cutout_shape=self.config.get("crop_to", None), + ) def data_loader(self, data_set): return torch.utils.data.DataLoader( @@ -60,7 +64,9 @@ def shape(self): class HSCDataSet(Dataset): - def __init__(self, path: Union[Path, str], transform=None): + def __init__( + self, path: Union[Path, str], *, transform=None, cutout_shape: Optional[tuple[int, int]] = None + ): """Initialize an HSC data set from a path. This involves several filesystem scan operations and will ultimately open and read the header info of every fits file in the given directory @@ -69,35 +75,41 @@ def __init__(self, path: Union[Path, str], transform=None): path : Union[Path, str] Path or string specifying the directory path to scan. It is expected that all files will be flat in this directory - transform : _type_, optional - _description_, by default None + transform : torchvision.transforms.v2.Transform, optional + Transformation to apply to every image in the dataset, by default None + cutout_shape: tuple[int,int], optional + Forces all cutouts to be a particular pixel size. RuntimeError is raised if this size is larger + than the pixel dimension of any cutout in the dataset. """ self.path = path self.transform = transform - self.files = self._scan_files() + self.files = self._scan_file_names() + self.dims = self._scan_file_dimensions() # We choose the first file in the dict as the prototypical set of filters # Any objects lacking this full set of filters will be pruned by # _prune_objects filters_ref = list(list(self.files.values())[0]) - self.num_filters = len(filters_ref) - self.object_ids = self._prune_objects(self.files, filters_ref) + self.cutout_shape = cutout_shape + + self.object_ids = self._prune_objects(filters_ref) - self.cutout_width, self.cutout_height = self._check_file_dimensions() + if self.cutout_shape is None: + self.cutout_shape = self._check_file_dimensions() # Set up our default transform to center-crop the image to the common size before # Applying any transforms we were passed. - crop = CenterCrop(size=(self.cutout_width, self.cutout_height)) + crop = CenterCrop(size=self.cutout_shape) self.transform = Compose([crop, self.transform]) if self.transform is not None else crop self.tensors = {} logger.info(f"HSC Data set loader has {len(self)} objects") - def _scan_files(self) -> dict[str, dict[str, str]]: + def _scan_file_names(self) -> dict[str, dict[str, str]]: """Class initialization helper Returns @@ -117,13 +129,29 @@ def _scan_files(self) -> dict[str, dict[str, str]]: if files.get(object_id) is None: files[object_id] = {} - files[object_id][filter] = filename + if files[object_id].get(filter) is None: + files[object_id][filter] = filename + else: + msg = f"Duplicate object ID {object_id} detected.\n" + msg += f"File {filename} conflicts with already scanned file {files[object_id][filter]} " + msg += "and will not be included in the data set." + logger.error(msg) return files - def _prune_objects(self, files: dict[str, dict[str, str]], filters_ref: list[str]) -> list[str]: - """Class initialization helper. Prunes files dict (which will be self.files). Removes any objects - which do not ahve all the filters specified in filters_ref + def _scan_file_dimensions(self) -> dict[str, tuple[int, int]]: + # Scan the filesystem to get the widths and heights of all images into a dict + return { + object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)] + for object_id in self._all_object_ids() + } + + def _prune_objects(self, filters_ref: list[str]) -> list[str]: + """Class initialization helper. Prunes objects from the list of objects. + + 1) Removes any objects which do not have all the filters specified in filters_ref + 2) If a cutout_shape was provided in the constructor, prunes files that are too small + for the chosen cutout size Parameters ---------- @@ -140,37 +168,49 @@ def _prune_objects(self, files: dict[str, dict[str, str]], filters_ref: list[str List of all object IDs which survived the prune. """ filters_ref = sorted(filters_ref) - prune_count = 0 - for object_id, filters in list(files.items()): + self.prune_count = 0 + for object_id, filters in list(self.files.items()): + # Drop objects with missing filters filters = sorted(list(filters)) if filters != filters_ref: - logger.warning( - f"HSCDataSet in {self.path} has the wrong group of filters for object {object_id}." - ) - logger.warning(f"Dropping object {object_id} from the dataset.") + msg = f"HSCDataSet in {self.path} has the wrong group of filters for object {object_id}." + self._prune_object(object_id, msg) logger.info(f"Filters for object {object_id} were {filters}") logger.debug(f"Reference filters were {filters_ref}") - prune_count += 1 - # Remove any object IDs for which we don't have all the filters - del files[object_id] - # Dump all object IDs into a list so there is an explicit indexing/ordering convention - # valid for the lifetime of this object. - object_ids = list(files) + # Drop objects that can't meet the coutout size provided + elif self.cutout_shape is not None: + for shape in self.dims[object_id]: + if shape[0] < self.cutout_shape[0] or shape[1] < self.cutout_shape[1]: + msg = f"A file for object {object_id} has shape ({shape[1]}px, {shape[1]}px)" + msg += " this is too small for the given cutout size of " + msg += f"({self.cutout_shape[0]}px, {self.cutout_shape[1]}px)" + self._prune_object(object_id, msg) + break # Log about the pruning process - pre_prune_object_count = len(object_ids) + prune_count - prune_fraction = prune_count / pre_prune_object_count + pre_prune_object_count = len(self.files) + self.prune_count + prune_fraction = self.prune_count / pre_prune_object_count if prune_fraction > 0.05: logger.error("Greater than 5% of objects in the data directory were pruned.") elif prune_fraction > 0.01: logger.warning("Greater than 1% of objects in the data directory were pruned.") - logger.info(f"Pruned {prune_count} out of {pre_prune_object_count} objects") + logger.info(f"Pruned {self.prune_count} out of {pre_prune_object_count} objects") + + def _prune_object(self, object_id, reason: str): + logger.warning(reason) + logger.warning(f"Dropping object {object_id} from the dataset") - return object_ids + del self.files[object_id] + del self.dims[object_id] + self.prune_count += 1 + + def _fits_file_dims(self, filepath): + with fits.open(filepath) as hdul: + return hdul[1].shape def _check_file_dimensions(self) -> tuple[int, int]: - """Class initialization helper. Scan all files to determine the minimal pixel size of images + """Class initialization helper. Find the maximal pixel size that all images can support It is assumed that all the cutouts will be of very similar size; however, HSC's cutout server does not return exactly the same number of pixels for every query, even when it @@ -187,15 +227,11 @@ def _check_file_dimensions(self) -> tuple[int, int]: The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping. """ - all_widths, all_heights = ([], []) - - for filepath in self._all_files(): - with fits.open(filepath) as hdul: - width, height = hdul[1].shape - all_widths.append(width) - all_heights.append(height) - + # Find the makximal cutout size that all images can support + all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list] cutout_width = np.min(all_widths) + + all_heights = [shape[1] for shape_list in self.dims.values() for shape in shape_list] cutout_height = np.min(all_heights) if ( @@ -204,12 +240,15 @@ def _check_file_dimensions(self) -> tuple[int, int]: or np.abs(np.max(all_widths) - np.mean(all_widths)) > 1 or np.abs(np.max(all_heights) - np.mean(all_heights)) > 1 ): - logger.warning("Some images differ from the mean width or height of all images by more than 1px") - logger.warning(f"Images will be cropped to ({cutout_width}px, {cutout_height}px)") - min_width_file = self._get_file(np.argmin(all_widths)) - logger.warning(f"See {min_width_file} for an example image of width {cutout_width}px") - min_height_file = self._get_file(np.argmin(all_heights)) - logger.warning(f"See {min_height_file} for an example image of height {cutout_height}px") + msg = "Some images differ from the mean width or height of all images by more than 1px\n" + msg += f"Images will be cropped to ({cutout_width}px, {cutout_height}px)\n" + try: + min_width_file = self._get_file(np.argmin(all_widths)) + min_height_file = self._get_file(np.argmin(all_heights)) + msg += f"See {min_width_file} for an example image of width {cutout_width}px\n" + msg += f"See {min_height_file} for an example image of height {cutout_height}px" + finally: + logger.warning(msg) return cutout_width, cutout_height @@ -224,7 +263,7 @@ def shape(self) -> tuple[int, int, int]: The second index is the width of each image The third index is the height of each image """ - return (self.num_filters, self.cutout_width, self.cutout_height) + return (self.num_filters, self.cutout_shape[0], self.cutout_shape[1]) def __len__(self) -> int: """Returns number of objects in this loader @@ -234,18 +273,31 @@ def __len__(self) -> int: int number of objects in this data loader """ - return len(self.object_ids) + return len(self.files) def __getitem__(self, idx: int) -> torch.Tensor: - if idx >= len(self.object_ids) or idx < 0: + if idx >= len(self.files) or idx < 0: raise IndexError # Use the list of object IDs for explicit indexing - object_id = self.object_ids[idx] + object_id = list(self.files.keys())[idx] + + return self._object_id_to_tensor(object_id) + + def __contains__(self, object_id: str) -> bool: + """Allows you to do `object_id in dataset` queries - tensor = self._object_id_to_tensor(object_id) + Parameters + ---------- + object_id : str + The object ID you'd like to know if is in the dataset - return tensor + Returns + ------- + bool + True of the object_id given is in the data set + """ + return object_id in list(self.files.keys()) and object_id in list(self.dims.keys()) def _get_file(self, index: int) -> Path: """Private indexing method across all files. @@ -269,32 +321,43 @@ def _get_file(self, index: int) -> Path: Path The path to the file """ - object_id = self.object_ids[int(index / self.num_filters)] + object_index = int(index / self.num_filters) + object_id = list(self.files.keys())[object_index] filters = self.files[object_id] filter_names = sorted(list(filters)) filter = filter_names[index % self.num_filters] return self._file_to_path(filters[filter]) - def _all_files(self) -> Path: + def _all_object_ids(self): + """Private read-only iterator over all object_ids that enforces a strict total order across + objects. Will not work prior to self.files initialization in __init__ + + Yields + ------ + Iterator[str] + Object IDs currently in the dataset + """ + for object_id in self.files: + yield object_id + + def _all_files(self): """ Private read-only iterator over all files that enforces a strict total order across - objects and filters. Will not work prior to self.object_ids, self.files, and self.path - initialization in __init__ + objects and filters. Will not work prior to self.files, and self.path initialization in __init__ Yields ------ Path The path to the file. """ - for object_id in self.object_ids: + for object_id in self._all_object_ids(): for filename in self._object_files(object_id): yield filename - def _object_files(self, object_id) -> Path: + def _object_files(self, object_id): """ Private read-only iterator over all files for a given object. This enforces a strict total order - across filters. Will not work prior to self.object_ids, self.files, and self.path initialization - in __init__ + across filters. Will not work prior to self.files, and self.path initialization in __init__ Yields ------ @@ -331,8 +394,8 @@ def _file_to_path(self, filename: str) -> Path: # # For now we just do it the naive way def _object_id_to_tensor(self, object_id: str) -> torch.Tensor: - """Converts an object_id to a pytorch tensor with dimenstions (self.num_filters, self.cutout_width, - self.cutout_height). This is done by reading the file and slicing away any excess pixels at the + """Converts an object_id to a pytorch tensor with dimenstions (self.num_filters, self.cutout_shape[0], + self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the far corners of the image from (0,0). The current implementation reads the files once the first time they are accessed, and then @@ -346,7 +409,7 @@ def _object_id_to_tensor(self, object_id: str) -> torch.Tensor: Returns ------- torch.Tensor - A tensor with dimension (self.num_filters, self.cutout_width, self.cutout_height) + A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1]) """ data_torch = self.tensors.get(object_id, None) if data_torch is not None: diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml index 28a12f0..a4b7195 100644 --- a/src/fibad/fibad_default_config.toml +++ b/src/fibad/fibad_default_config.toml @@ -59,6 +59,13 @@ name = "HSCDataLoader" # Directory path where the data is stored path = "./data" +# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small. +# +# If not provided, the default is to scan the directory for the smallest dimensioned files, and use +# those pixel dimensions as the crop size. +# +#crop_to = [100,100] + # Default PyTorch DataLoader parameters batch_size = 500 shuffle = true diff --git a/src/fibad/models/example_autoencoder.py b/src/fibad/models/example_autoencoder.py index b995311..ae4ae05 100644 --- a/src/fibad/models/example_autoencoder.py +++ b/src/fibad/models/example_autoencoder.py @@ -108,7 +108,7 @@ def train(self, trainloader, device=None): torch.set_grad_enabled(True) - print(f"len(trainloder) = {len(trainloader)}") + # print(f"len(trainloder) = {len(trainloader)}") for epoch in range(self.config.get("epochs", 2)): running_loss = 0.0 for batch_num, data in enumerate(trainloader, 0): diff --git a/tests/fibad/test_hsc_dataset.py b/tests/fibad/test_hsc_dataset.py index 7e2ba28..5ce74c9 100644 --- a/tests/fibad/test_hsc_dataset.py +++ b/tests/fibad/test_hsc_dataset.py @@ -51,7 +51,9 @@ def __exit__(self, *exc): patcher.stop() -def generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0) -> dict: +def generate_files( + num_objects=10, num_filters=5, shape=(100, 100), offset=0, infill_str="all_filters" +) -> dict: """Generates a dictionary to pass in to FakeFitsFS. This generates a dict from filename->shape tuple for a set of uniform fake fits files @@ -72,6 +74,8 @@ def generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0) -> What are the dimensions of the image in each fits file, by default (100,100) offset : int, optional What is the first object_id to start with, by default 0 + infill_str: str, optional + What to put in the fake filename in between the object ID and filter name. By default "all_filters" Returns ------- @@ -82,7 +86,7 @@ def generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0) -> test_files = {} for object_id in range(offset, num_objects + offset): for filter in filters: - test_files[f"{object_id:017d}_all_filters_{filter}.fits"] = shape + test_files[f"{object_id:017d}_{infill_str}_{filter}.fits"] = shape return test_files @@ -104,6 +108,34 @@ def test_load(caplog): assert caplog.text == "" +def test_load_duplicate(caplog): + """Test to ensure duplicate fits files that reference the same object id and filter create the + appropriate error messages. + """ + caplog.set_level(logging.ERROR) + test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263)) + duplicate_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263), infill_str="duplicate") + test_files.update(duplicate_files) + with FakeFitsFS(test_files): + a = HSCDataSet("thispathdoesnotexist") + + # Only 10 objects should load + assert len(a) == 10 + + # The number of filters, and image dimensions should be correct + assert a.shape() == (5, 262, 263) + + # We should get duplicate object errors + assert "Duplicate object ID" in caplog.text + + # We should get errors that include the duplicate filenames + assert "_duplicate_" in caplog.text + + # The duplicate files should not be in the data set + for filepath in a._all_files(): + assert "_duplicate_" not in str(filepath) + + def test_prune_warn_1_percent(caplog): """Test to ensure when >1% of loaded objects are missing a filter, that is a warning and that the resulting dataset drops the objects that are missing filters @@ -122,11 +154,14 @@ def test_prune_warn_1_percent(caplog): assert len(a) == 98 # Object 2 should not be loaded - assert "00000000000000101" not in a.object_ids + assert "00000000000000101" not in a # We should Error log because greater than 5% of the objects were pruned assert "Greater than 1% of objects in the data directory were pruned." in caplog.text + # We should warn that we dropped an object explicitly + assert "Dropping object" in caplog.text + def test_prune_error_5_percent(caplog): """Test to ensure when >5% of loaded objects are missing a filter, that is an error @@ -146,7 +181,7 @@ def test_prune_error_5_percent(caplog): assert len(a) == 18 # Object 20 should not be loaded - assert "00000000000000020" not in a.object_ids + assert "00000000000000020" not in a # We should Error log because greater than 5% of the objects were pruned assert "Greater than 5% of objects in the data directory were pruned." in caplog.text @@ -200,7 +235,7 @@ def test_crop_warn_2px_larger(caplog): assert len(a) == 70 assert a.shape() == (5, 99, 99) - # No warnings should be printed since we're within 1px of the mean size + # We should warn that images differ assert "Some images differ" in caplog.text @@ -226,5 +261,27 @@ def test_crop_warn_2px_smaller(caplog): assert len(a) == 70 assert a.shape() == (5, 98, 98) - # No warnings should be printed since we're within 1px of the mean size + # We should warn that images differ assert "Some images differ" in caplog.text + + +def test_prune_size(caplog): + """Test to ensure images that are too small will be pruned from the data set when a custom size is + passed.""" + caplog.set_level(logging.WARNING) + test_files = {} + test_files.update(generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0)) + # Add some images with dimensions 1 px larger + test_files.update(generate_files(num_objects=10, num_filters=5, shape=(101, 101), offset=20)) + # Add some images with dimensions 2 px smaller + test_files.update(generate_files(num_objects=10, num_filters=5, shape=(98, 98), offset=30)) + + with FakeFitsFS(test_files): + a = HSCDataSet("thispathdoesnotexist", cutout_shape=(99, 99)) + + assert len(a) == 20 + assert a.shape() == (5, 99, 99) + + # We should warn that we are dropping objects and the reason + assert "Dropping object" in caplog.text + assert "too small" in caplog.text