Initial version of filter_catalog feature.
- We can take a fits file as a config
- We filter object_ids out of a big dataset based on it
- We also skip filesystem checks if there is enough info in the filter catalog.
- Lacks any unit testing
- Added the prepare verb, but right now it just gives you the dataset object
  when run from a notebook.
mtauraso committed Nov 1, 2024
1 parent c12e7d5 commit fb7a066
Showing 5 changed files with 137 additions and 11 deletions.
99 changes: 91 additions & 8 deletions src/fibad/data_sets/hsc_data_set.py
@@ -9,6 +9,7 @@
import numpy as np
import torch
from astropy.io import fits
from astropy.table import Table
from torch.utils.data import Dataset
from torchvision.transforms.v2 import CenterCrop, Compose, Lambda

@@ -240,6 +241,10 @@ def __getitem__(self, idx: int) -> torch.Tensor:
return self.data[self.indexes[idx]]


dim_dict = dict[str, tuple[int, int]]
files_dict = dict[str, dict[str, str]]


class HSCDataSetContainer(Dataset):
def __init__(self, config):
# TODO: What will be a reasonable set of transformations?
@@ -250,12 +255,14 @@ def __init__(self, config):

crop_to = config["data_set"]["crop_to"]
filters = config["data_set"]["filters"]
filter_catalog = config["data_set"]["filter_catalog"]

self._init_from_path(
config["general"]["data_dir"],
transform=transform,
cutout_shape=crop_to if crop_to else None,
filters=filters if filters else None,
filter_catalog=Path(filter_catalog) if filter_catalog else None,
)

def _init_from_path(
@@ -265,6 +272,7 @@ def _init_from_path(
transform=None,
cutout_shape: Optional[tuple[int, int]] = None,
filters: Optional[list[str]] = None,
filter_catalog: Optional[Path] = None,
):
"""__init__ helper. Initialize an HSC data set from a path. This involves several filesystem scan
operations and will ultimately open and read the header info of every fits file in the given directory
@@ -284,12 +292,31 @@ def _init_from_path(
cutouts which do not have fits files corresponding to every filter in the list will be dropped
from the data set. Defaults to None. If not provided, the filters available on the filesystem for
the first object in the directory will be used.
filter_catalog: Path, optional
Path to a .fits file which specifies objects and/or files to use directly, bypassing the default
of attempting to use every file in the path.
Columns for this fits file are object_id (required), filter (optional), filename (optional), and
dim (optional tuple of the x/y pixel size of each image).
- Filenames must be relative to the path provided to this function.
- When filter and filename are both provided, initialization skips a directory listing, which
can provide better performance on large datasets.
- When filter, filename, and dim are all provided, we also skip opening the files to read
their dimensions, which again helps performance on large datasets.
"""
self.path = path
self.transform = transform

self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()
self.filter_catalog = self._read_filter_catalog(filter_catalog)
if isinstance(self.filter_catalog, tuple):
self.files = self.filter_catalog[0]
self.dims = self.filter_catalog[1]
logger.debug(self.dims)
elif isinstance(self.filter_catalog, dict):
self.files = self.filter_catalog
self.dims = self._scan_file_dimensions()
else:
self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()

# If no filters provided, we choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by _prune_objects
@@ -313,7 +340,7 @@ def _init_from_path(

logger.info(f"HSC Data set loader has {len(self)} objects")

def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:
def _scan_file_names(self, filters: Optional[list[str]] = None) -> files_dict:
"""Class initialization helper
Parameters
@@ -335,11 +362,17 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:

files = {}
# Go scan the path for object ID's so we have a list.
for filepath in Path(self.path).glob("[0-9]*.fits"):
for filepath in Path(self.path).iterdir():
filename = filepath.name
m = re.match(full_regex, filename)

# Skip files that don't match the pattern.
# If we are filtering based on a user-provided catalog of object_ids, filter out any
# object_ids not in the catalog. Do this before the regex match so irrelevant files
# are discarded quickly.
if isinstance(self.filter_catalog, list) and filename[:17] not in self.filter_catalog:
continue

m = re.match(full_regex, filename)
# Skip files that don't allow us to extract both object_id and filter
if m is None:
continue

@@ -359,7 +392,57 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:

return files

def _scan_file_dimensions(self) -> dict[str, tuple[int, int]]:
def _read_filter_catalog(
self, filter_catalog_path: Optional[Path]
) -> Optional[Union[list[str], files_dict, tuple[files_dict, dim_dict]]]:
if filter_catalog_path is None:
return None

if not filter_catalog_path.exists():
logger.error(f"Filter catalog file {filter_catalog_path} given in config does not exist.")
return None

table = Table.read(filter_catalog_path, format="fits")
colnames = table.colnames
if "object_id" not in colnames:
logger.error(f"Filter catalog file {filter_catalog_path} has no column object_id")
return None

# We are dealing with just a list of object_ids
if "filter" not in colnames and "filename" not in colnames:
return list(table["object_id"])

# Or a table that has one of filter/filename without the other
elif "filter" not in colnames or "filename" not in colnames:
msg = f"Filter catalog file {filter_catalog_path} provides one of the filter or filename "
msg += "columns without the other. A filesystem scan will still occur since both are not defined."
logger.warning(msg)
return list(set(table["object_id"]))

# We have filter and filename defined so we can assemble the catalog at the file level.
filter_catalog = {}
if "dim" in colnames:
dim_catalog = {}

for row in table:
object_id = row["object_id"]
filter = row["filter"]
filename = row["filename"]

if object_id not in filter_catalog:
filter_catalog[object_id] = {}

filter_catalog[object_id][filter] = filename

# Dimension is optional
if "dim" in colnames:
if object_id not in dim_catalog:
dim_catalog[object_id] = []
dim_catalog[object_id].append(tuple(row["dim"]))

return (filter_catalog, dim_catalog) if "dim" in colnames else filter_catalog

def _scan_file_dimensions(self) -> dim_dict:
# Scan the filesystem to get the widths and heights of all images into a dict
return {
object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)]
Expand Down Expand Up @@ -445,7 +528,7 @@ def _check_file_dimensions(self) -> tuple[int, int]:
The minimum width and height in pixels of the entire dataset. In other words: the maximal image
size in pixels that can be generated from ALL cutout images via cropping.
"""
# Find the makximal cutout size that all images can support
# Find the maximal cutout size that all images can support
all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list]
cutout_width = np.min(all_widths)

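
For reference, `_read_filter_catalog` above hands back one of three shapes, matching the branches in `_init_from_path` (a sketch with hypothetical values):

```python
# 1. A bare list of object_ids: the filesystem scan still runs, but is filtered.
["00000000000000001", "00000000000000002"]

# 2. A files_dict (object_id -> {filter -> filename}): skips the directory listing.
{"00000000000000001": {"HSC-G": "00000000000000001_HSC-G.fits"}}

# 3. A (files_dict, dim_dict) tuple: additionally skips opening files for dimensions.
(
    {"00000000000000001": {"HSC-G": "00000000000000001_HSC-G.fits"}},
    {"00000000000000001": [(262, 263)]},
)
```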
5 changes: 5 additions & 0 deletions src/fibad/downloadCutout/downloadCutout.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
from collections.abc import Generator
from typing import IO, Any, Callable, Optional, Union, cast

import numpy as np

__all__ = []


@@ -762,6 +764,9 @@ def parse_bool(s: Union[str, bool]) -> bool:
if isinstance(s, bool):
return s

if isinstance(s, np.bool_):
return bool(s)

return {
"false": False,
"f": False,
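
Context for the new `parse_bool` branch (the motivation is an assumption, not stated in the diff): boolean values read out of numpy-backed arrays and tables are `numpy.bool_`, which is not a subclass of Python's `bool`, so the existing `isinstance(s, bool)` check misses them. A quick demonstration:

```python
import numpy as np

v = np.array([True])[0]      # a numpy boolean scalar (numpy.bool_, aliased np.bool in NumPy 2.x)
print(type(v))               # <class 'numpy.bool_'>
print(isinstance(v, bool))   # False -- why parse_bool needs a numpy-aware branch
print(bool(v))               # True
```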
10 changes: 9 additions & 1 deletion src/fibad/fibad.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ class Fibad:
CLI functions in fibad_cli are implemented by calling this class
"""

verbs = ["train", "predict", "download"]
verbs = ["train", "predict", "download", "prepare"]

def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True):
"""Initialize fibad. Always applies the default config, and merges it with any provided config file.
@@ -177,3 +177,11 @@ def predict(self, **kwargs):
from .predict import run

return run(config=self.config, **kwargs)

def prepare(self, **kwargs):
"""
See fibad.prepare.run()
"""
from .prepare import run

return run(config=self.config, **kwargs)

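
A sketch of the notebook usage the commit message describes (assuming `Fibad` is importable from the package root; the config path is hypothetical):

```python
from fibad import Fibad

fibad_instance = Fibad(config_file="my_config.toml")  # hypothetical config path
dataset = fibad_instance.prepare()  # currently just returns the dataset object
```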
4 changes: 4 additions & 0 deletions src/fibad/fibad_default_config.toml
@@ -90,6 +90,10 @@ crop_to = false
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
filters = false

# A fits file which specifies object IDs used to filter the large dataset in [general].data_dir
# down to a subset. Implementation is dataset class dependent. Default is false, meaning no filtering.
filter_catalog = false

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32
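
To enable the new option, a user config would override the default with a path (a sketch; the path is hypothetical):

```toml
[data_set]
filter_catalog = "/path/to/filter_catalog.fits"
```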
30 changes: 28 additions & 2 deletions tests/fibad/test_hsc_dataset.py
@@ -31,7 +31,7 @@ def __init__(self, test_files: dict):
self.test_files = test_files

mock_paths = [Path(x) for x in list(test_files.keys())]
target = "fibad.data_sets.hsc_data_set.Path.glob"
target = "fibad.data_sets.hsc_data_set.Path.iterdir"
self.patchers.append(mock.patch(target, return_value=mock_paths))

mock_fits_open = mock.Mock(side_effect=self._open_file)
@@ -53,7 +53,15 @@ def __exit__(self, *exc):
patcher.stop()


def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, validate_size=0, seed=False):
def mkconfig(
crop_to=False,
filters=False,
train_size=0.2,
test_size=0.6,
validate_size=0,
seed=False,
filter_catalog=False,
):
"""Makes a configuration that points at nonexistent path so HSCDataSet.__init__ will create an object,
and our FakeFitsFS shim can be called.
"""
@@ -62,6 +70,7 @@ def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, validate_size=0, seed=False):
"data_set": {
"crop_to": crop_to,
"filters": filters,
"filter_catalog": filter_catalog,
},
"prepare": {
"seed": seed,
@@ -564,3 +573,20 @@ def test_split_and_conflicting_datasets():

with pytest.raises(RuntimeError):
a.current_split.logical_and(b.current_split)


def test_filter_catalog(caplog):
"""Test to ensure loading a perfectly regular set of files works"""
caplog.set_level(logging.WARNING)
test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))
with FakeFitsFS(test_files):
a = HSCDataSet(mkconfig(), split=None)

# 10 objects should load
assert len(a) == 10

# The number of filters, and image dimensions should be correct
assert a.shape() == (5, 262, 263)

# No warnings should be printed
assert caplog.text == ""
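
Since the commit message notes the feature lacks unit tests, a follow-up test for the object_id-list branch of `_read_filter_catalog` might look like this (a sketch; it relies on `_read_filter_catalog` never touching `self`, which holds in this diff):

```python
from unittest import mock

from astropy.table import Table

from fibad.data_sets.hsc_data_set import HSCDataSetContainer


def test_read_filter_catalog_object_id_list(tmp_path):
    """Sketch: a catalog with only an object_id column should come back as a plain list."""
    catalog_path = tmp_path / "filter_catalog.fits"
    Table({"object_id": ["00000000000000001", "00000000000000002"]}).write(catalog_path)

    # _read_filter_catalog never touches self, so a Mock can stand in for the instance.
    result = HSCDataSetContainer._read_filter_catalog(mock.Mock(), catalog_path)
    assert list(result) == ["00000000000000001", "00000000000000002"]
```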
