diff --git a/example_notebooks/GettingStartedDownloader.ipynb b/example_notebooks/GettingStartedDownloader.ipynb index 9102080..0558b78 100644 --- a/example_notebooks/GettingStartedDownloader.ipynb +++ b/example_notebooks/GettingStartedDownloader.ipynb @@ -45,10 +45,39 @@ "metadata": {}, "outputs": [], "source": [ - "fibad_instance = fibad.Fibad(config_file=fibad_config)\n", + "import fibad\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "# os.chdir(Path(fibad.__file__).parent/\"..\"/\"..\")\n", + "fibad_instance = fibad.Fibad(config_file=\"fibad_config.toml\")\n", "\n", "fibad_instance.download()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "widths, heights = fibad_instance.raw_data_dimensions()\n", + "\n", + "fig, axs = plt.subplots(1, 2)\n", + "fig.set_figwidth(12)\n", + "\n", + "_, _, _ = axs[0].hist(heights, range=(260, 270), bins=10)\n", + "_, _, _ = axs[1].hist(widths, range=(260, 270), bins=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index b275e24..0529fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ examples = [ dev = [ "asv==0.6.4", # Used to compute performance benchmarks "jupyter", # Clears output from Jupyter notebooks + "matplotlib", # For example notebooks "pre-commit", # Used to run checks before finalizing a git commit "pytest", "pytest-cov", # Used to report total code coverage diff --git a/src/fibad/download.py b/src/fibad/download.py index 11a3529..e25ef08 100644 --- a/src/fibad/download.py +++ b/src/fibad/download.py @@ -2,8 +2,9 @@ import logging from pathlib import Path from threading import Thread -from typing import Optional +from typing import Optional, Union +from astropy.io import fits from astropy.table import Table, hstack import fibad.downloadCutout.downloadCutout as dC @@ -21,12 +22,16 @@ class Downloader: VARIABLE_FIELDS = ["tract", "ra", "dec"] # These are the column names we retain when writing a rect out to the manifest.fits file - RECT_COLUMN_NAMES = VARIABLE_FIELDS + ["filter", "sw", "sh", "rerun", "type"] + RECT_COLUMN_NAMES = VARIABLE_FIELDS + ["filter", "sw", "sh", "rerun", "type", "dim"] MANIFEST_FILE_NAME = "manifest.fits" - @staticmethod - def run(config): + def __init__(self, config): + self.config = config.get("download", {}) + self.cutout_path = Path(self.config.get("cutout_dir")).resolve() + self.manifest_file = self.cutout_path / Downloader.MANIFEST_FILE_NAME + + def run(self): """ Main entrypoint for downloading cutouts from HSC for use with fibad @@ -36,42 +41,39 @@ def run(config): Runtime configuration as a nested dictionary """ - config = config.get("download", {}) - logger.info("Download command Start") - fits_file = Path(config.get("fits_file", "")).resolve() + fits_file = Path(self.config.get("fits_file", "")).resolve() logger.info(f"Reading in fits catalog: {fits_file}") # Filter the fits file for the fields we want column_names = ["object_id"] + Downloader.VARIABLE_FIELDS locations = Downloader.filterfits(fits_file, column_names) # If offet/length specified, filter to that length - offset = config.get("offset", 0) - end = offset + config.get("num_sources", None) + offset = self.config.get("offset", 0) + end = offset + self.config.get("num_sources", None) if end is not None: locations = locations[offset:end] - cutout_path = Path(config.get("cutout_dir")).resolve() - logger.info(f"Downloading cutouts to {cutout_path}") + logger.info(f"Downloading cutouts to {self.cutout_path}") logger.info("Making a list of cutouts...") # Make a list of rects to pass to downloadCutout - rects = Downloader.create_rects( - locations, offset=0, default=Downloader.rect_from_config(config), path=cutout_path + self.rects = Downloader.create_rects( + locations, offset=0, default=Downloader.rect_from_config(self.config), path=self.cutout_path ) logger.info("Checking the list against currently downloaded cutouts...") # Prune any previously downloaded rects from our list using the manifest from the previous download - rects = Downloader._prune_downloaded_rects(cutout_path, rects) + self.rects = self._prune_downloaded_rects() # Early return if there is nothing to download. - if len(rects) == 0: + if len(self.rects) == 0: logger.info("Download already complete according to manifest.") return # Create thread objects for each of our worker threads - num_threads = config.get("concurrent_connections", 2) + num_threads = self.config.get("concurrent_connections", 2) if num_threads > 5: raise RuntimeError("This client only opens 5 connections or fewer.") @@ -89,22 +91,26 @@ def _batched(iterable, n): yield batch logger.info("Dividing cutouts among threads...") - thread_rects = list(_batched(rects, int(len(rects) / num_threads))) if num_threads != 1 else [rects] + thread_rects = ( + list(_batched(self.rects, int(len(self.rects) / num_threads))) + if num_threads != 1 + else [self.rects] + ) # Empty dictionaries for the threads to create download manifests in - thread_manifests = [dict() for _ in range(num_threads)] + self.thread_manifests = [dict() for _ in range(num_threads)] shared_thread_args = ( - config["username"], - config["password"], - DownloadStats(print_interval_s=config.get("stats_print_interval", 60)), + self.config["username"], + self.config["password"], + DownloadStats(print_interval_s=self.config.get("stats_print_interval", 60)), ) shared_thread_kwargs = { - "retrywait": config.get("retry_wait", 30), - "retries": config.get("retries", 3), - "timeout": config.get("timeout", 3600), - "chunksize": config.get("chunk_size", 990), + "retrywait": self.config.get("retry_wait", 30), + "retries": self.config.get("retries", 3), + "timeout": self.config.get("timeout", 3600), + "chunksize": self.config.get("chunk_size", 990), } download_threads = [ @@ -114,7 +120,7 @@ def _batched(iterable, n): daemon=True, # daemon so these threads will die when the main thread is interrupted args=(thread_rects[i],) # rects + shared_thread_args # username, password, download stats - + (i, thread_manifests[i]), # thread_num, manifest + + (i, self.thread_manifests[i]), # thread_num, manifest kwargs=shared_thread_kwargs, ) for i in range(num_threads) @@ -125,12 +131,11 @@ def _batched(iterable, n): [thread.start() for thread in download_threads] [thread.join() for thread in download_threads] finally: # Ensure manifest is written even when we get a KeyboardInterrupt during download - Downloader.write_manifest(thread_manifests, cutout_path) + self._write_manifest() logger.info("Done") - @staticmethod - def _prune_downloaded_rects(cutout_path: Path, rects: list[dC.Rect]) -> list[dC.Rect]: + def _prune_downloaded_rects(self): """Prunes already downloaded rects using the manifest in `cutout_path`. `rects` passed in is mutated by this operation @@ -155,13 +160,13 @@ def _prune_downloaded_rects(cutout_path: Path, rects: list[dC.Rect]) -> list[dC. """ # print(rects) # Read in any prior manifest. - prior_manifest = Downloader.read_manifest(cutout_path) + prior_manifest = self.manifest_to_rects() # If we found a manifest, we are resuming a download if len(prior_manifest) != 0: # Filter rects to figure out which ones are completely downloaded. # This operation consumes prior_manifest in the process - rects[:] = [rect for rect in rects if Downloader._keep_rect(rect, prior_manifest)] + self.rects[:] = [rect for rect in self.rects if Downloader._keep_rect(rect, prior_manifest)] # if prior_manifest was not completely consumed, than the earlier download attempted # some sky locations which would not be included in the current download, and we have @@ -170,12 +175,12 @@ def _prune_downloaded_rects(cutout_path: Path, rects: list[dC.Rect]) -> list[dC. # print(len(prior_manifest)) # print (prior_manifest) raise RuntimeError( - f"""{cutout_path/Downloader.MANIFEST_FILE_NAME} describes a download with + f"""{self.manifest_file} describes a download with sky locations that would not be downloaded in the download currently being attempted. Are you sure you are resuming the correct download? Deleting the manifest and cutout files will start the download from scratch""" ) - return rects + return self.rects @staticmethod def _keep_rect(location_rect: dC.Rect, prior_manifest: dict[dC.Rect, str]) -> bool: @@ -217,8 +222,7 @@ def _keep_rect(location_rect: dC.Rect, prior_manifest: dict[dC.Rect, str]) -> bo return keep_rect - @staticmethod - def write_manifest(thread_manifests: list[dict[dC.Rect, str]], file_path: Path): + def _write_manifest(self): """Write out manifest fits file that is an inventory of the download. The manifest fits file should have columns object_id, ra, dec, tract, filter, filename @@ -246,24 +250,17 @@ def write_manifest(thread_manifests: list[dict[dC.Rect, str]], file_path: Path): sh: Semi-height of the cutout box in degrees rerun: The data release in use e.g. pdr3_wide type: coadd, warp, or other values allowed by the HSC docs + dim: Tuple of integers with the dimensions of the image. - Parameters - ---------- - thread_manifests : list[dict[dC.Rect,str]] - Manifests mapping rects -> Filename or status message. Each manifest came from a separate thread. - - file_path : Path - Full path to the location where the manifest file ought be written. The manifest file will be - named manifest.fits """ logger.info("Assembling download manifest") # Start building a combined manifest from all threads from the ground truth of the prior manifest # in this directory, which we will be overwriting. - combined_manifest = Downloader.read_manifest(file_path) + combined_manifest = self.manifest_to_rects() # Combine all thread manifests with the prior manifest, so that the current status of a downloaded # rect overwrites any status from the prior run (which is no longer relevant.) - for manifest in thread_manifests: + for manifest in self.thread_manifests: combined_manifest.update(manifest) logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.") @@ -293,38 +290,50 @@ def write_manifest(thread_manifests: list[dict[dC.Rect, str]], file_path: Path): # print (key, len(val), val) manifest_table = Table(columns) - manifest_table.write(file_path / Downloader.MANIFEST_FILE_NAME, overwrite=True, format="fits") + manifest_table.write(self.manifest_file, overwrite=True, format="fits") logger.info("Finished writing download manifest") - @staticmethod - def read_manifest(file_path: Path) -> dict[dC.Rect, str]: + def get_manifest(self): + """Get the current downloader manifest, which is a list of files where download has been attempted + The format of the table is outlined in _write_manifest() + + Returns + ------- + astropy.table.Table + The entire download manifest + """ + if self.manifest_file.exists(): + return Table.read(self.manifest_file, format="fits") + + return None + + def manifest_to_rects(self) -> dict[dC.Rect, str]: """Read the manifest.fits file from the given directory and return its contents as a dictionary with downloadCutout.Rectangles as keys and filenames as values. If now manifest file is found, an empty dict is returned. - Parameters - ---------- - file_path : Path - Where to find the manifest file - Returns ------- dict[dC.Rect, str] A dictionary containing all the rects in the manifest and all the filenames, or empty dict if no manifest is found. """ - filename = file_path / Downloader.MANIFEST_FILE_NAME - if filename.exists(): - manifest_table = Table.read(filename, format="fits") + manifest_table = self.get_manifest() + if manifest_table is not None: rects = Downloader.create_rects( - locations=manifest_table, fields=Downloader.RECT_COLUMN_NAMES, path=file_path + locations=manifest_table, fields=Downloader.RECT_COLUMN_NAMES, path=self.cutout_path ) return {rect: filename for rect, filename in zip(rects, manifest_table["filename"])} else: return {} + @staticmethod + def _rect_hook(rect: dC.Rect, filename: Union[Path, str]): + with fits.open(filename) as hdul: + rect.dim = hdul[1].shape + @staticmethod def download_thread( rects: list[dC.Rect], @@ -365,6 +374,7 @@ def download_thread( password=password, onmemory=False, request_hook=stats_hook, + rect_hook=Downloader._rect_hook, manifest=manifest, **kwargs, ) @@ -459,7 +469,7 @@ def create_rects( rects = [] fields = fields if fields else Downloader.VARIABLE_FIELDS for index, location in enumerate(locations): - args = {field: location[field] for field in fields} + args = {field: location.get(field) for field in fields} args["lineno"] = index + offset args["tract"] = str(args["tract"]) # Sets the file name on the rect to be the object_id, also includes other rect fields diff --git a/src/fibad/downloadCutout/downloadCutout.py b/src/fibad/downloadCutout/downloadCutout.py index a0c1706..f6a5fd7 100644 --- a/src/fibad/downloadCutout/downloadCutout.py +++ b/src/fibad/downloadCutout/downloadCutout.py @@ -58,6 +58,8 @@ def export(obj): export("ANYTRACT") ANYTRACT = -1 +export("NODIM") +NODIM = (-1, -1) export("ALLFILTERS") ALLFILTERS = "all" @@ -324,6 +326,8 @@ class Rect: File name format (without extension ".fits") lineno Line number in a list file. + dim + Dimensions of downloaded file """ rerun: str = default_rerun @@ -339,6 +343,7 @@ class Rect: variance: bool = default_get_variance name: str = default_name lineno: int = 0 + dim: tuple[int, int] = NODIM @staticmethod def create( @@ -355,6 +360,7 @@ def create( variance: Union[str, bool, None] = None, name: Union[str, None] = None, lineno: Union[int, None] = None, + dim: Union[tuple[int, int], None] = None, default: Union["Rect", None] = None, ) -> "Rect": """ @@ -401,6 +407,8 @@ def create( File name format (without extension ".fits") lineno Line number in a list file. + dim + Dimensions of the image in pixels. default Default value. @@ -437,6 +445,8 @@ def create( rect.name = str(name) if lineno is not None: rect.lineno = int(lineno) + if dim is not None: + rect.dim = dim return rect @@ -1204,7 +1214,8 @@ def _download_chunk( onmemory: bool, request_hook: Optional[ Callable[[urllib.request.Request, datetime.datetime, datetime.datetime, int, int], Any] - ], + ] = None, + rect_hook: Optional[Callable[[Rect, str], Any]] = None, **kwargs_urlopen, ) -> Optional[list]: """ @@ -1231,6 +1242,9 @@ def _download_chunk( request_hook Function that is called with the response of all requests made Intended to support bandwidth instrumentation. + rect_hook + Function to be called on every rectangle downloaded. The callback recieves the rect and the filename + as arguments kwargs_urlopen Additional keyword args are passed through to urllib.request.urlopen @@ -1312,6 +1326,8 @@ def _download_chunk( os.makedirs(dirname, exist_ok=True) with open(filename, "wb") as fout: _splice(fitem, fout) + if rect_hook: + rect_hook(rect, filename) if manifest is not None: manifest[rect] = filename if request_hook: diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 07389fd..1ed9adc 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -135,6 +135,23 @@ def _initialize_log_handlers(self): formatter = logging.Formatter("[%(asctime)s %(name)s:%(levelname)s] %(message)s") handler.setFormatter(formatter) + def raw_data_dimensions(self) -> tuple[list[int], list[int]]: + """Gives the dimensions of underlying data that forms input to the training, and inference + steps. This is the raw data that the data loader must normalize to the model + + Returns + ------- + tuple[list[int],list[int]] + widths and heights of all images available locally. + """ + from .download import Downloader + + downloader = Downloader(config=self.config) + manifest = downloader.get_manifest() + widths = [int(dim[0]) for dim in manifest["dim"]] + heights = [int(dim[1]) for dim in manifest["dim"]] + return widths, heights + def train(self, **kwargs): """ See Fibad.train.run() @@ -149,7 +166,8 @@ def download(self, **kwargs): """ from .download import Downloader - return Downloader.run(config=self.config, **kwargs) + downloader = Downloader(config=self.config) + return downloader.run(**kwargs) def predict(self, **kwargs): """