From b390d589936017dea744662ab764fe4e199ebb29 Mon Sep 17 00:00:00 2001 From: Michael Tauraso Date: Fri, 8 Nov 2024 15:30:54 -0800 Subject: [PATCH] Rebuild a fits manifest from an HSC data directory. - Added a new verb rebuild_manifest - When run with the HSC dataset class this verb will: 0) Scan the data directory and ingest HSC cutout files 1) Read in the original catalog file configured for download for metadata 2) Write out rebuilt_manifest.fits in the data directory - Fixed up config resolution so that fibad_config.toml in the cwd works again for CLI invocations. --- src/fibad/config_utils.py | 63 +++++++------ src/fibad/data_sets/hsc_data_set.py | 137 +++++++++++++++++++++++++++- src/fibad/download.py | 7 +- src/fibad/fibad.py | 16 +++- 4 files changed, 180 insertions(+), 43 deletions(-) diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py index a918039..0962871 100644 --- a/src/fibad/config_utils.py +++ b/src/fibad/config_utils.py @@ -76,15 +76,15 @@ def __init__( runtime_config_filepath: Union[Path, str] = None, default_config_filepath: Union[Path, str] = DEFAULT_CONFIG_FILEPATH, ): - self.fibad_default_config = self._read_runtime_config(default_config_filepath) + self.fibad_default_config = ConfigManager._read_runtime_config(default_config_filepath) - self.runtime_config_filepath = runtime_config_filepath - if self.runtime_config_filepath is None: + self.runtime_config_filepath = ConfigManager.resolve_runtime_config(runtime_config_filepath) + if self.runtime_config_filepath is DEFAULT_CONFIG_FILEPATH: self.user_specific_config = ConfigDict() else: - self.user_specific_config = self._read_runtime_config(self.runtime_config_filepath) + self.user_specific_config = ConfigManager._read_runtime_config(self.runtime_config_filepath) - self.external_library_config_paths = self._find_external_library_default_config_paths( + self.external_library_config_paths = ConfigManager._find_external_library_default_config_paths( self.user_specific_config ) @@ -93,7 +93,7 @@ def __init__( self.config = self.merge_configs(self.overall_default_config, self.user_specific_config) if not self.config["general"]["dev_mode"]: - self._validate_runtime_config(self.config, self.overall_default_config) + ConfigManager._validate_runtime_config(self.config, self.overall_default_config) @staticmethod def _read_runtime_config(config_filepath: Union[Path, str] = DEFAULT_CONFIG_FILEPATH) -> ConfigDict: @@ -232,38 +232,37 @@ def _validate_runtime_config(runtime_config: ConfigDict, default_config: ConfigD raise RuntimeError(msg) ConfigManager._validate_runtime_config(runtime_config[key], default_config[key]) + @staticmethod + def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = None) -> Path: + """Resolve a user-supplied runtime config to where we will actually pull config from. -def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = None) -> Path: - """Resolve a user-supplied runtime config to where we will actually pull config from. - - 1) If a runtime config file is specified, we will use that file - 2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use that file - 3) If no file is specified and there is no file named "fibad_config.toml" in the current working directory - we will exclusively work off the configuration defaults in the packaged "fibad_default_config.toml" - file. + 1) If a runtime config file is specified, we will use that file. + 2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use it. + 3) If no file is specified and there is no file named "fibad_config.toml" in the cwd we will + exclusively work off the configuration defaults in the packaged "fibad_default_config.toml" file. - Parameters - ---------- - runtime_config_filepath : Union[Path, str, None], optional - Location of the supplied config file, by default None + Parameters + ---------- + runtime_config_filepath : Union[Path, str, None], optional + Location of the supplied config file, by default None - Returns - ------- - Path - Path to the configuration file ultimately used for config resolution. When we fall back to the - package supplied default config file, the Path to that file is returned. - """ - if isinstance(runtime_config_filepath, str): - runtime_config_filepath = Path(runtime_config_filepath) + Returns + ------- + Path + Path to the configuration file ultimately used for config resolution. When we fall back to the + package supplied default config file, the Path to that file is returned. + """ + if isinstance(runtime_config_filepath, str): + runtime_config_filepath = Path(runtime_config_filepath) - # If a named config exists in cwd, and no config specified on cmdline, use cwd. - if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists(): - runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH + # If a named config exists in cwd, and no config specified on cmdline, use cwd. + if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists(): + runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH - if runtime_config_filepath is None: - runtime_config_filepath = DEFAULT_CONFIG_FILEPATH + if runtime_config_filepath is None: + runtime_config_filepath = DEFAULT_CONFIG_FILEPATH - return runtime_config_filepath + return runtime_config_filepath def create_results_dir(config: ConfigDict, postfix: Union[Path, str]) -> Path: diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py index 7df1bd8..1c2c327 100644 --- a/src/fibad/data_sets/hsc_data_set.py +++ b/src/fibad/data_sets/hsc_data_set.py @@ -13,6 +13,17 @@ from torch.utils.data import Dataset from torchvision.transforms.v2 import CenterCrop, Compose, Lambda +from fibad.download import Downloader +from fibad.downloadCutout.downloadCutout import ( + parse_bool, + parse_degree, + parse_latitude, + parse_longitude, + parse_rerun, + parse_tract_opt, + parse_type, +) + from .data_set_registry import fibad_data_set logger = logging.getLogger(__name__) @@ -94,6 +105,9 @@ def __getitem__(self, idx: int) -> torch.Tensor: def __len__(self) -> int: return len(self.current_split) + def rebuild_manifest(self, config): + return self.container._rebuild_manifest(config) + class HSCDataSetSplit(Dataset): def __init__( @@ -553,6 +567,88 @@ def _check_file_dimensions(self) -> tuple[int, int]: return cutout_width, cutout_height + def _rebuild_manifest(self, config): + if self.filter_catalog: + raise RuntimeError("Cannot rebuild manifest. Set the filter_catalog=false and rerun") + + logger.info("Reading in catalog file... ") + location_table = Downloader.filterfits( + Path(config["download"]["fits_file"]).resolve(), ["object_id", "ra", "dec"] + ) + + obj_to_ra = { + str(location_table["object_id"][index]): location_table["ra"][index] + for index in range(len(location_table)) + } + obj_to_dec = { + str(location_table["object_id"][index]): location_table["dec"][index] + for index in range(len(location_table)) + } + + del location_table + + logger.info("Assembling Manifest...") + + # These are the column names expected in a manifest file by the downloader + column_names = Downloader.MANIFEST_COLUMN_NAMES + columns = {column_name: [] for column_name in column_names} + + # These we vary every object and must be implemented below + dynamic_column_names = ["object_id", "filter", "dim", "tract", "ra", "dec", "filename"] + # These are pulled from config ("sw", "sh", "rerun", "type", "image", "mask", and "variance") + static_column_names = [name for name in column_names if name not in dynamic_column_names] + + # Check that all column names we need for a manifest are either in static or dynamic columns + for column_name in column_names: + if column_name not in static_column_names and column_name not in dynamic_column_names: + raise RuntimeError(f"Error Assembling manifest {column_name} not implemented") + + static_values = { + "sw": parse_degree(config["download"]["sw"]), + "sh": parse_degree(config["download"]["sh"]), + "rerun": parse_rerun(config["download"]["rerun"]), + "type": parse_type(config["download"]["type"]), + "image": parse_bool(config["download"]["image"]), + "mask": parse_bool(config["download"]["mask"]), + "variance": parse_bool(config["download"]["variance"]), + } + + for object_id, filter, filename, dim in self._all_files_full(): + for static_col in static_column_names: + columns[static_col].append(static_values[static_col]) + + for dynamic_col in dynamic_column_names: + if dynamic_col == "object_id": + columns[dynamic_col].append(int(object_id)) + elif dynamic_col == "filter": + columns[dynamic_col].append(filter) + elif dynamic_col == "dim": + columns[dynamic_col].append(dim) + elif dynamic_col == "tract": + # There's value in pulling tract from the filename rather than the download catalog + # in case The catalog had it wrong, the filename will have the value the cutout server + # provided. + tract = filename.split("_")[4] + columns[dynamic_col].append(parse_tract_opt(tract)) + elif dynamic_col == "ra": + ra = obj_to_ra[object_id] + columns[dynamic_col].append(parse_longitude(ra)) + elif dynamic_col == "dec": + dec = obj_to_dec[object_id] + columns[dynamic_col].append(parse_latitude(dec)) + elif dynamic_col == "filename": + columns[dynamic_col].append(filename) + else: + # The tower of if statements has been entirely to create this failure path. + # which will be hit when someone alters dynamic column names above without also + # writing an implementation. + raise RuntimeError(f"No implementation to process column {dynamic_col}") + + logger.info("Writing rebuilt manifest...") + manifest_table = Table(columns) + rebuilt_manifest_path = Path(config["general"]["data_dir"]) / "rebuilt_manifest.fits" + manifest_table.write(rebuilt_manifest_path, overwrite=True, format="fits") + def shape(self) -> tuple[int, int, int]: """Shape of the individual cutouts this will give to a model @@ -641,6 +737,25 @@ def ids(self): for object_id in self.files: yield object_id + def _all_files_full(self): + """ + Private read-only iterator over all files that enforces a strict total order across + objects and filters. Will not work prior to self.files, and self.path initialization in __init__ + + Yields + ------ + Tuple[object_id, filter, filename, dim] + Members of this tuple are + - The object_id as a string + - The filter name as a string + - The filename relative to self.path + - A tuple containing the dimensions of the fits file in pixels. + """ + for object_id in self.ids(): + dims = self.dims[object_id] + for idx, (filter, filename) in enumerate(self._filter_filename(object_id)): + yield (object_id, filter, filename, dims[idx]) + def _all_files(self): """ Private read-only iterator over all files that enforces a strict total order across @@ -655,6 +770,22 @@ def _all_files(self): for filename in self._object_files(object_id): yield filename + def _filter_filename(self, object_id): + """ + Private read-only iterator over all files for a given object. This enforces a strict total order + across filters. Will not work prior to self.files initialization in __init__ + + Yields + ------ + filter_name, file name + The name of a filter and the file name for the fits file. + The file name is relative to self.path + """ + filters = self.files[object_id] + filter_names = sorted(list(filters)) + for filter_name in filter_names: + yield filter_name, filters[filter_name] + def _object_files(self, object_id): """ Private read-only iterator over all files for a given object. This enforces a strict total order @@ -665,10 +796,8 @@ def _object_files(self, object_id): Path The path to the file. """ - filters = self.files[object_id] - filter_names = sorted(list(filters)) - for filter in filter_names: - yield self._file_to_path(filters[filter]) + for _, filename in self._filter_filename(object_id): + yield self._file_to_path(filename) def _file_to_path(self, filename: str) -> Path: """Turns a filename into a full path suitable for open. Equivalent to: diff --git a/src/fibad/download.py b/src/fibad/download.py index 9e7e73e..8f78da6 100644 --- a/src/fibad/download.py +++ b/src/fibad/download.py @@ -28,6 +28,8 @@ class Downloader: # of the immutable fields that we rely on for hash checks are also included. RECT_COLUMN_NAMES = list(dict.fromkeys(VARIABLE_FIELDS + dC.Rect.immutable_fields + ["dim"])) + MANIFEST_COLUMN_NAMES = RECT_COLUMN_NAMES + ["filename", "object_id"] + MANIFEST_FILE_NAME = "manifest.fits" def __init__(self, config): @@ -280,9 +282,8 @@ def _write_manifest(self): logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.") # Convert the combined manifest into an astropy table by building a dict of {column_name: column_data} - # for all the fields in a rect, plus our object_id and filename. - column_names = Downloader.RECT_COLUMN_NAMES + ["filename", "object_id"] - columns = {column_name: [] for column_name in column_names} + # for all the fields we require in a manifest + columns = {column_name: [] for column_name in Downloader.MANIFEST_COLUMN_NAMES} for rect, msg in combined_manifest.items(): # This parsing relies on the name format set up in create_rects to work properly diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 474a005..4d8cd13 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Union -from .config_utils import ConfigManager, resolve_runtime_config +from .config_utils import ConfigManager class Fibad: @@ -14,7 +14,7 @@ class Fibad: CLI functions in fibad_cli are implemented by calling this class """ - verbs = ["train", "predict", "download", "prepare"] + verbs = ["train", "predict", "download", "prepare", "rebuild_manifest"] def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True): """Initialize fibad. Always applies the default config, and merges it with any provided config file. @@ -88,7 +88,7 @@ def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool # Setup our handlers from config self._initialize_log_handlers() - self.logger.info(f"Runtime Config read from: {resolve_runtime_config(config_file)}") + self.logger.info(f"Runtime Config read from: {ConfigManager.resolve_runtime_config(config_file)}") def _initialize_log_handlers(self): """Private initialization helper, Adds handlers and level setting to the global self.logger object""" @@ -180,8 +180,16 @@ def predict(self, **kwargs): def prepare(self, **kwargs): """ - See Fibad.predict.run() + See Fibad.prepare.run() """ from .prepare import run return run(config=self.config, **kwargs) + + def rebuild_manifest(self, **kwargs): + """ + See Fibad.rebuild_manifest.run() + """ + from .rebuild_manifest import run + + return run(config=self.config, **kwargs)