diff --git a/noxfile.py b/noxfile.py index 290a6b80b..7637a8314 100644 --- a/noxfile.py +++ b/noxfile.py @@ -50,6 +50,18 @@ def dev(session: nox.Session) -> None: session.run(python, "-m", "pip", "install", "-e", ".[dev,test,doc]", external=True) +@nox.session(python="3.10") +def lint(session): + """ + Run the linter. + """ + session.install(".[dev]") + # run isort first since black disagrees with it + session.run("isort", "./src") + session.run("black", "./src", "--line-length=79") + session.run("flake8", "./src", "--max-line-length", "120", "--exclude", "./src/crowsetta/_vendor") + + # ---- used by sessions that "clean up" data for tests def clean_dir(dir_path): """ diff --git a/pyproject.toml b/pyproject.toml index 338f07096..af6a6c801 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,11 @@ dependencies = [ [project.optional-dependencies] dev = [ "twine >=3.3.0", - "black >=20.8b1", - "ipython >=7.0" + "black >=23.7.0", + "flake8 >=6.0.0", + "ipython >=7.0", + "isort >=5.12.0", + "pycln >=2.1.3", ] test = [ "pytest >=6.2.1", diff --git a/src/scripts/download_autoannotate_data.py b/src/scripts/download_autoannotate_data.py index 5d1d95155..03aaed609 100644 --- a/src/scripts/download_autoannotate_data.py +++ b/src/scripts/download_autoannotate_data.py @@ -4,6 +4,7 @@ https://github.com/NickleDave/bfsongrepo/blob/main/src/scripts/download_dataset.py """ from __future__ import annotations + import argparse import pathlib import shutil @@ -12,17 +13,16 @@ import urllib.request import warnings - DATA_TO_DOWNLOAD = { "gy6or6": { "sober.repo1.gy6or6.032212.wav.csv.tar.gz": { "MD5": "8c88b46ba87f9784d3690cc8ee4bf2f4", - "download": "https://figshare.com/ndownloader/files/37509160" + "download": "https://figshare.com/ndownloader/files/37509160", }, "sober.repo1.gy6or6.032312.wav.csv.tar.gz": { "MD5": "063ba4d50d1b94009b4b00f0a941d098", - "download": "https://figshare.com/ndownloader/files/37509172" - } + "download": 
"https://figshare.com/ndownloader/files/37509172", + }, } } @@ -39,44 +39,40 @@ def reporthook(count: int, block_size: int, total_size: int) -> None: progress_size = int(count * block_size) speed = int(progress_size / (1024 * duration)) percent = int(count * block_size * 100 / total_size) - sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % - (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.write( + "\r...%d%%, %d MB, %d KB/s, %d seconds passed" + % (percent, progress_size / (1024 * 1024), speed, duration) + ) sys.stdout.flush() -def download_dataset(download_urls_by_bird_ID: dict, - bfsongrepo_dir: pathlib.Path) -> None: +def download_dataset( + download_urls_by_bird_ID: dict, bfsongrepo_dir: pathlib.Path +) -> None: """download the dataset, given a dict of download urls""" tar_dir = bfsongrepo_dir / "tars" tar_dir.mkdir() # top-level keys are bird ID: bl26lb16, gr41rd51, ... for bird_id, tars_dict in download_urls_by_bird_ID.items(): - print( - f'Downloading .tar files for bird: {bird_id}' - ) - # bird ID -> dict where keys are .tar.gz filenames mapping to download url + MD5 hash + print(f"Downloading .tar files for bird: {bird_id}") + # bird ID -> dict + # where keys are .tar.gz filenames mapping to download url + MD5 hash for tar_name, url_md5_dict in tars_dict.items(): - print( - f'Downloading tar: {tar_name}' - ) - download_url = url_md5_dict['download'] + print(f"Downloading tar: {tar_name}") + download_url = url_md5_dict["download"] filename = tar_dir / tar_name urllib.request.urlretrieve(download_url, filename, reporthook) - print('\n') + print("\n") def extract_tars(bfsongrepo_dir: pathlib.Path) -> None: tar_dir = bfsongrepo_dir / "tars" # made by download_dataset function - tars = sorted(tar_dir.glob('*.tar.gz')) + tars = sorted(tar_dir.glob("*.tar.gz")) for tar_path in tars: - print( - f"\nunpacking: {tar_path}" - ) + print(f"\nunpacking: {tar_path}") shutil.unpack_archive( - filename=tar_path, - 
extract_dir=bfsongrepo_dir, - format="gztar" + filename=tar_path, extract_dir=bfsongrepo_dir, format="gztar" ) @@ -87,7 +83,7 @@ def main(dst: str | pathlib.Path) -> None: raise NotADirectoryError( f"Value for 'dst' argument not recognized as a directory: {dst}" ) - bfsongrepo_dir = dst / 'bfsongrepo' + bfsongrepo_dir = dst / "bfsongrepo" if bfsongrepo_dir.exists(): warnings.warn( f"Directory already exists: {bfsongrepo_dir}\n" @@ -103,9 +99,7 @@ def main(dst: str | pathlib.Path) -> None: "If that fails, please download files for tutorial manually from the 'download' links in tutorial page." ) from e - print( - f'Downloading Bengalese Finch Song Repository to: {bfsongrepo_dir}' - ) + print(f"Downloading Bengalese Finch Song Repository to: {bfsongrepo_dir}") download_dataset(DATA_TO_DOWNLOAD, bfsongrepo_dir) extract_tars(bfsongrepo_dir) @@ -115,11 +109,13 @@ def get_parser() -> argparse.ArgumentParser: """get ArgumentParser used to parse command-line arguments""" parser = argparse.ArgumentParser() parser.add_argument( - '--dst', - default='.', - help=("Destination where dataset should be downloaded. " - "Default is '.', i.e., current working directory " - "from which this script is run.'") + "--dst", + default=".", + help=( + "Destination where dataset should be downloaded. 
" + "Default is '.', i.e., current working directory " + "from which this script is run.'" + ), ) return parser diff --git a/src/vak/__about__.py b/src/vak/__about__.py index 60750bbb4..fac64e073 100644 --- a/src/vak/__about__.py +++ b/src/vak/__about__.py @@ -20,7 +20,9 @@ __title__ = "vak" -__summary__ = "a neural network toolbox for animal vocalizations and bioacoustics" +__summary__ = ( + "a neural network toolbox for animal vocalizations and bioacoustics" +) __uri__ = "https://github.com/NickleDave/vak" __version__ = "1.0.0a1" diff --git a/src/vak/__init__.py b/src/vak/__init__.py index 20aad5e67..f6bbb4eed 100644 --- a/src/vak/__init__.py +++ b/src/vak/__init__.py @@ -1,15 +1,3 @@ -from .__about__ import ( - __author__, - __commit__, - __copyright__, - __email__, - __license__, - __summary__, - __title__, - __uri__, - __version__, -) - from . import ( __main__, cli, @@ -28,17 +16,38 @@ train, transforms, ) - +from .__about__ import ( + __author__, + __commit__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, +) __all__ = [ "__main__", + "__author__", + "__commit__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__version__", "cli", "common", "config", "datasets", + "eval", "learncurve", "metrics", "models", + "nets", "nn", "plot", "predict", diff --git a/src/vak/__main__.py b/src/vak/__main__.py index 955768fb1..c3f6c0bac 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -11,7 +11,7 @@ def get_parser(): """returns ArgumentParser instance used by main()""" parser = argparse.ArgumentParser( - prog='vak', + prog="vak", description="vak command-line interface", formatter_class=argparse.RawTextHelpFormatter, ) diff --git a/src/vak/cli/__init__.py b/src/vak/cli/__init__.py index 19722ee96..cb0cc02aa 100644 --- a/src/vak/cli/__init__.py +++ b/src/vak/cli/__init__.py @@ -3,7 +3,6 @@ from . 
import cli, eval, learncurve, predict, prep, train - __all__ = [ "cli", "eval", diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index f10f0b7b1..d6d2eaca3 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -1,25 +1,30 @@ def eval(toml_path): from .eval import eval + eval(toml_path=toml_path) def train(toml_path): from .train import train + train(toml_path=toml_path) def learncurve(toml_path): from .learncurve import learning_curve + learning_curve(toml_path=toml_path) def predict(toml_path): from .predict import predict + predict(toml_path=toml_path) def prep(toml_path): from .prep import prep + prep(toml_path=toml_path) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 04dd08640..329bf38b0 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -5,7 +5,6 @@ from .. import eval as eval_module from ..common.logging import config_logging_for_cli, log_version - logger = logging.getLogger(__name__) @@ -32,10 +31,7 @@ def eval(toml_path): # ---- set up logging --------------------------------------------------------------------------------------------- config_logging_for_cli( - log_dst=cfg.eval.output_dir, - log_stem="eval", - level="INFO", - force=True + log_dst=cfg.eval.output_dir, log_stem="eval", level="INFO", force=True ) log_version(logger) diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 5bf302f6c..ff9869014 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -1,12 +1,11 @@ import logging -from pathlib import Path import shutil +from pathlib import Path from .. 
import config, learncurve from ..common.logging import config_logging_for_cli, log_version from ..common.paths import generate_results_dir_name_as_path - logger = logging.getLogger(__name__) @@ -32,17 +31,16 @@ def learning_curve(toml_path): ) # ---- set up directory to save output ----------------------------------------------------------------------------- - results_path = generate_results_dir_name_as_path(cfg.learncurve.root_results_dir) + results_path = generate_results_dir_name_as_path( + cfg.learncurve.root_results_dir + ) results_path.mkdir(parents=True) # copy config file into results dir now that we've made the dir shutil.copy(toml_path, results_path) # ---- set up logging ---------------------------------------------------------------------------------------------- config_logging_for_cli( - log_dst=results_path, - log_stem="learncurve", - level="INFO", - force=True + log_dst=results_path, log_stem="learncurve", level="INFO", force=True ) log_version(logger) logger.info("Logging results to {}".format(results_path)) diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 7393aab26..38701b87f 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -5,7 +5,6 @@ from .. 
import predict as predict_module from ..common.logging import config_logging_for_cli, log_version - logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ def predict(toml_path): log_dst=cfg.predict.output_dir, log_stem="predict", level="INFO", - force=True + force=True, ) log_version(logger) logger.info("Logging results to {}".format(cfg.prep.output_dir)) diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index da843d6ff..13636a4de 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -1,8 +1,8 @@ # note NO LOGGING -- we configure logger inside `core.prep` # so we can save log file inside dataset directory -from pathlib import Path import shutil import warnings +from pathlib import Path import toml @@ -92,7 +92,9 @@ def prep(toml_path): ) # now that we've checked that, go ahead and parse the sections we want - cfg = config.parse.from_toml_path(toml_path, sections=SECTIONS_PREP_SHOULD_PARSE) + cfg = config.parse.from_toml_path( + toml_path, sections=SECTIONS_PREP_SHOULD_PARSE + ) # notice we ignore any other option/values in the 'purpose' section, # see https://github.com/NickleDave/vak/issues/334 and https://github.com/NickleDave/vak/issues/314 if cfg.prep is None: diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 9f11b1a95..91c89bb95 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -1,13 +1,12 @@ import logging -from pathlib import Path import shutil +from pathlib import Path from .. import config from .. 
import train as train_module from ..common.logging import config_logging_for_cli, log_version from ..common.paths import generate_results_dir_name_as_path - logger = logging.getLogger(__name__) @@ -32,17 +31,16 @@ def train(toml_path): ) # ---- set up directory to save output ----------------------------------------------------------------------------- - results_path = generate_results_dir_name_as_path(cfg.train.root_results_dir) + results_path = generate_results_dir_name_as_path( + cfg.train.root_results_dir + ) results_path.mkdir(parents=True) # copy config file into results dir now that we've made the dir shutil.copy(toml_path, results_path) # ---- set up logging ---------------------------------------------------------------------------------------------- config_logging_for_cli( - log_dst=results_path, - log_stem="train", - level="INFO", - force=True + log_dst=results_path, log_stem="train", level="INFO", force=True ) log_version(logger) logger.info("Logging results to {}".format(results_path)) diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index dfd57accc..e453adbb6 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -25,7 +25,6 @@ validators, ) - __all__ = [ "annotation", "constants", diff --git a/src/vak/common/annotation.py b/src/vak/common/annotation.py index e0c279321..0cb6cad2c 100644 --- a/src/vak/common/annotation.py +++ b/src/vak/common/annotation.py @@ -1,16 +1,16 @@ from __future__ import annotations -from collections import Counter + import copy import os import pathlib +from collections import Counter from typing import Optional, Union import crowsetta import numpy as np import pandas as pd -from . import files -from . import constants +from . 
import constants, files from .typing import PathLike @@ -36,7 +36,10 @@ def format_from_df(dataset_df: pd.DataFrame) -> str: annot_format = dataset_df["annot_format"].unique() if len(annot_format) == 1: annot_format = annot_format.item() - if annot_format is None or annot_format == constants.NO_ANNOTATION_FORMAT: + if ( + annot_format is None + or annot_format == constants.NO_ANNOTATION_FORMAT + ): return None elif len(annot_format) > 1: raise ValueError( @@ -46,8 +49,9 @@ def format_from_df(dataset_df: pd.DataFrame) -> str: return annot_format -def from_df(dataset_df: pd.DataFrame, - annot_root: str | pathlib.Path | None = None) -> list[crowsetta.Annotation] | None: +def from_df( + dataset_df: pd.DataFrame, annot_root: str | pathlib.Path | None = None +) -> list[crowsetta.Annotation] | None: """Get list of annotations from a dataframe representing a dataset. @@ -111,7 +115,9 @@ def from_df(dataset_df: pd.DataFrame, annots = scribe.from_file(annot_path).to_annot() # as long as we have at least as many annotations as there are rows in the dataframe - if (isinstance(annots, list) and len(annots) >= len(dataset_df)) or ( # case 1 + if ( + isinstance(annots, list) and len(annots) >= len(dataset_df) + ) or ( # case 1 isinstance(annots, crowsetta.Annotation) and len(dataset_df) == 1 ): # case 2 if isinstance(annots, crowsetta.Annotation): @@ -119,7 +125,9 @@ def from_df(dataset_df: pd.DataFrame, annots ] # wrap in list for map_annotated_to_annot to iterate over it # then we can try and map those annotations to the rows - audio_annot_map = map_annotated_to_annot(dataset_df["audio_path"].values, annots, annot_format) + audio_annot_map = map_annotated_to_annot( + dataset_df["audio_path"].values, annots, annot_format + ) # sort by row of dataframe annots = [ audio_annot_map[audio_path] @@ -138,9 +146,12 @@ def from_df(dataset_df: pd.DataFrame, # --> there is a unique annotation file (path) for each row, iterate over them to get labels from each annot_paths = 
dataset_df["annot_path"].values if annot_root: - annot_paths = [annot_root / annot_path for annot_path in annot_paths] + annot_paths = [ + annot_root / annot_path for annot_path in annot_paths + ] annots = [ - scribe.from_file(annot_path).to_annot() for annot_path in annot_paths + scribe.from_file(annot_path).to_annot() + for annot_path in annot_paths ] else: @@ -175,7 +186,10 @@ def files_from_dir(annot_dir, annot_format): elif isinstance(format_class.ext, tuple): # then we actually have to determine whether there's any files for either format for ext_to_test in format_class.ext: - if len(sorted(pathlib.Path(annot_dir).glob(f'*{ext_to_test}'))) > 0: + if ( + len(sorted(pathlib.Path(annot_dir).glob(f"*{ext_to_test}"))) + > 0 + ): ext = ext_to_test if ext is None: raise ValueError( @@ -197,8 +211,7 @@ class AudioFilenameNotFoundError(Exception): """ -def audio_filename_from_path(path: PathLike, - audio_ext: str = None) -> str: +def audio_filename_from_path(path: PathLike, audio_ext: str = None) -> str: """Find the name of an audio file within a filename by removing extensions until finding an audio extension, then return the name of that audio file @@ -243,12 +256,12 @@ def audio_filename_from_path(path: PathLike, Part of filename that precedes audio extension. 
""" if audio_ext: - if audio_ext.startswith('.'): + if audio_ext.startswith("."): audio_ext = audio_ext[1:] if audio_ext not in constants.VALID_AUDIO_FORMATS: raise ValueError( - f'Not a valid extension for audio formats: {audio_ext}\n' - f'Valid formats are: {constants.VALID_AUDIO_FORMATS}' + f"Not a valid extension for audio formats: {audio_ext}\n" + f"Valid formats are: {constants.VALID_AUDIO_FORMATS}" ) extensions_to_look_for = [audio_ext] else: @@ -274,12 +287,15 @@ class MapUsingNotatedPathError(BaseException): """Error raised when :func:`vak.annotation._map_using_notated_path` cannot map the filename of an annotation file to the name of an annotated file""" + pass -def _map_using_notated_path(annotated_files: list[PathLike], - annot_list: list[crowsetta.Annotation], - audio_ext: Optional[str] = None) -> dict[str: crowsetta.Annotation]: +def _map_using_notated_path( + annotated_files: list[PathLike], + annot_list: list[crowsetta.Annotation], + audio_ext: Optional[str] = None, +) -> dict: """Map a :class:`list` of annotated files to a :class:`list` of :class:`crowsetta.Annotation` instances, using the ``notated_path`` attribute of the @@ -333,7 +349,9 @@ def _map_using_notated_path(annotated_files: list[PathLike], keys_set = set(keys) if len(keys_set) < len(keys): - duplicates = [item for item, count in Counter(keys).items() if count > 1] + duplicates = [ + item for item, count in Counter(keys).items() if count > 1 + ] raise ValueError( f"found multiple annotations with the same audio filename(s): {duplicates}" ) @@ -345,7 +363,8 @@ def _map_using_notated_path(annotated_files: list[PathLike], audio_filename_annot_map = { # NOTE HERE WE GET FILENAMES FROM EACH annot.notated_path, # BELOW we get filenames from each annotated_file - audio_filename_from_path(annot.notated_path): annot for annot in annot_list + audio_filename_from_path(annot.notated_path): annot + for annot in annot_list } # Make a copy of ``annotated_files`` from which @@ -360,11 +379,15 @@ def 
_map_using_notated_path(annotated_files: list[PathLike], # that match with stems from each annot.notated_path; # e.g. find '~/path/to/llb3/llb3_0003_2018_04_23_14_18_54.wav.mat' that # should match with ``Annotation(notated_path='llb3_0003_2018_04_23_14_18_54.wav')`` - audio_filename_from_annotated_file = audio_filename_from_path(annotated_file) + audio_filename_from_annotated_file = audio_filename_from_path( + annotated_file + ) try: - annot = audio_filename_annot_map[audio_filename_from_annotated_file] + annot = audio_filename_annot_map[ + audio_filename_from_annotated_file + ] except KeyError as e: - raise MapUsingNotatedPathError ( + raise MapUsingNotatedPathError( "Could not map an annotation to an annotated file path " "using `vak.annotation.audio_filename_from_path` to get " "an audio filename from the annotated file path." @@ -390,14 +413,17 @@ class MapUsingExtensionError(BaseException): """Error raised when :func:`vak.annotation._map_using_ext` cannot map the filename of an annotation file to the name of an annotated file""" + pass -def _map_using_ext(annotated_files: list[PathLike], - annot_list: list[crowsetta.Annotation], - annot_format: str, - method: str, - annotated_ext: str | None = None) -> dict[str: crowsetta.Annotation]: +def _map_using_ext( + annotated_files: list[PathLike], + annot_list: list[crowsetta.Annotation], + annot_format: str, + method: str, + annotated_ext: str | None = None, +) -> dict: """Map a list of annotated files to a :class:`list` of :class:`crowsetta.Annotation` instances, by either removing the extension of the annotation format, @@ -440,7 +466,7 @@ def _map_using_ext(annotated_files: list[PathLike], Where each key is path to annotated file, and its value is the corresponding ``crowsetta.Annotation``. 
""" - if method not in {'remove', 'replace'}: + if method not in {"remove", "replace"}: raise ValueError( f"`method` must be one of: {{'remove', 'replace'}}, but was: '{method}'" ) @@ -449,14 +475,16 @@ def _map_using_ext(annotated_files: list[PathLike], pathlib.Path(annotated_file) for annotated_file in annotated_files ] - if method == 'replace': + if method == "replace": if annotated_ext is None: - annotated_ext_set = set([annotated_file.suffix for annotated_file in annotated_files]) + annotated_ext_set = set( + [annotated_file.suffix for annotated_file in annotated_files] + ) if len(annotated_ext_set) > 1: raise ValueError( "Found more than one extension in annotated files, " "unclear which extension to use when mapping to annotations " - f"with 'replace' method. Extensions found: {ext_set}" + f"with 'replace' method. Extensions found: {annotated_ext_set}" ) annotated_ext = annotated_ext_set.pop() @@ -477,13 +505,13 @@ def _map_using_ext(annotated_files: list[PathLike], # NOTE that by convention the `ext` attribute # of all Crowsetta annotation format classes # begins with a period - annotated_name = annot.annot_path.name.replace(annot_class.ext, '') + annotated_name = annot.annot_path.name.replace(annot_class.ext, "") elif isinstance(annot_class.ext, tuple): # handle the case where an annotation format can have multiple extensions, # e.g., ``Format.ext == ('.csv', '.txt')`` for ext in annot_class.ext: if annot.annot_path.name.endswith(ext): - annotated_name = annot.annot_path.name.replace(ext, '') + annotated_name = annot.annot_path.name.replace(ext, "") break if annotated_name is None: @@ -496,7 +524,7 @@ def _map_using_ext(annotated_files: list[PathLike], # NOTE we don't have to do anything else for method=='remove' # since we just removed the extension - if method == 'replace': + if method == "replace": annotated_name = annotated_name + annotated_ext annotated_filename_annot_map[annotated_name] = annot @@ -528,10 +556,12 @@ def 
_map_using_ext(annotated_files: list[PathLike], return {str(path): annot for path, annot in annotated_annot_map.items()} -def map_annotated_to_annot(annotated_files: Union[list, np.array], - annot_list: list[crowsetta.Annotation], - annot_format: str, - annotated_ext: str | None = None) -> dict[pathlib.Path : crowsetta.Annotation]: +def map_annotated_to_annot( + annotated_files: Union[list, np.array], + annot_list: list[crowsetta.Annotation], + annot_format: str, + annotated_ext: str | None = None, +) -> dict: """Map annotated files, i.e. audio or spectrogram files, to their corresponding annotations. @@ -595,31 +625,43 @@ def map_annotated_to_annot(annotated_files: Union[list, np.array], reference section of the documentation: https://vak.readthedocs.io/en/latest/reference/filenames.html """ - if type(annotated_files) == np.ndarray: # e.g., vak DataFrame['spect_path'].values + if isinstance(annotated_files, np.ndarray): # e.g., vak DataFrame['spect_path'].values annotated_files = annotated_files.tolist() - if annot_format in ('birdsong-recognition-dataset', 'yarden', 'generic-seq'): - annotated_annot_map = _map_using_notated_path(annotated_files, annot_list) + if annot_format in ( + "birdsong-recognition-dataset", + "yarden", + "generic-seq", + ): + annotated_annot_map = _map_using_notated_path( + annotated_files, annot_list + ) else: try: - annotated_annot_map = _map_using_ext(annotated_files, annot_list, annot_format, method='remove') + annotated_annot_map = _map_using_ext( + annotated_files, annot_list, annot_format, method="remove" + ) except MapUsingExtensionError: try: - annotated_annot_map = _map_using_ext(annotated_files, annot_list, annot_format, method='replace', - annotated_ext=annotated_ext) + annotated_annot_map = _map_using_ext( + annotated_files, + annot_list, + annot_format, + method="replace", + annotated_ext=annotated_ext, + ) except MapUsingExtensionError as e: raise ValueError( - 'Could not map annotated files to annotations.\n' - 'Please see 
this section in the `vak` documentation:\n' - 'https://vak.readthedocs.io/en/latest/howto/howto_prep_annotate.html' - '#how-does-vak-know-which-annotations-go-with-which-annotated-files' + "Could not map annotated files to annotations.\n" + "Please see this section in the `vak` documentation:\n" + "https://vak.readthedocs.io/en/latest/howto/howto_prep_annotate.html" + "#how-does-vak-know-which-annotations-go-with-which-annotated-files" ) from e return annotated_annot_map -def has_unlabeled(annot: crowsetta.Annotation, - duration: float) -> bool: +def has_unlabeled(annot: crowsetta.Annotation, duration: float) -> bool: """Returns ``True`` if an annotated sequence has unlabeled segments. Tests whether an instance of ``crowsetta.Annotation.seq`` has @@ -653,7 +695,13 @@ def has_unlabeled(annot: crowsetta.Annotation, # Handle edge case where there are no annotated segments in annotation file # See https://github.com/vocalpy/vak/issues/378 return True - has_unlabeled_intervals = np.any((annot.seq.onsets_s[1:] - annot.seq.offsets_s[:-1]) > 0.) - has_unlabeled_before_first_onset = annot.seq.onsets_s[0] > 0. - has_unlabeled_after_last_offset = duration - annot.seq.offsets_s[-1] > 0. 
- return has_unlabeled_intervals or has_unlabeled_before_first_onset or has_unlabeled_after_last_offset + has_unlabeled_intervals = np.any( + (annot.seq.onsets_s[1:] - annot.seq.offsets_s[:-1]) > 0.0 + ) + has_unlabeled_before_first_onset = annot.seq.onsets_s[0] > 0.0 + has_unlabeled_after_last_offset = duration - annot.seq.offsets_s[-1] > 0.0 + return ( + has_unlabeled_intervals + or has_unlabeled_before_first_onset + or has_unlabeled_after_last_offset + ) diff --git a/src/vak/common/constants.py b/src/vak/common/constants.py index 548b89149..e8aaad94e 100644 --- a/src/vak/common/constants.py +++ b/src/vak/common/constants.py @@ -5,10 +5,9 @@ import crowsetta import numpy as np +import soundfile from evfuncs import load_cbin from scipy.io import loadmat -import soundfile - # ---- audio files ---- AUDIO_FORMAT_FUNC_MAP = { @@ -28,7 +27,7 @@ VALID_SPECT_FORMATS = list(SPECT_FORMAT_LOAD_FUNCTION_MAP.keys()) # ---- valid types of training data, the $x$ that goes into a network -VALID_X_SOURCES = {'audio', 'spect'} +VALID_X_SOURCES = {"audio", "spect"} # ---- annotation files ---- VALID_ANNOT_FORMATS = crowsetta.formats.as_list() diff --git a/src/vak/common/converters.py b/src/vak/common/converters.py index c00d5e69d..6b349e182 100644 --- a/src/vak/common/converters.py +++ b/src/vak/common/converters.py @@ -1,11 +1,11 @@ -from pathlib import Path from distutils.util import strtobool +from pathlib import Path def bool_from_str(value): - if type(value) == bool: + if isinstance(value, bool): return value - elif type(value) == str: + elif isinstance(value, str): return bool(strtobool(value)) @@ -42,17 +42,21 @@ def range_str(range_str, sort=True): """ # adapted from # http://code.activestate.com/recipes/577279-generate-list-of-numbers-from-hyphenated-and-comma/ - s = "".join(range_str.split()) # removes white space + _ = "".join(range_str.split()) # removes white space list_range = [] for substr in range_str.split(","): subrange = substr.split("-") if len(subrange) not in 
[1, 2]: raise SyntaxError( - "unable to parse range {} in labelset {}.".format(subrange, substr) + "unable to parse range {} in labelset {}.".format( + subrange, substr + ) ) list_range.extend([int(subrange[0])]) if len( subrange - ) == 1 else list_range.extend(range(int(subrange[0]), int(subrange[1]) + 1)) + ) == 1 else list_range.extend( + range(int(subrange[0]), int(subrange[1]) + 1) + ) if sort: list_range.sort() diff --git a/src/vak/common/files/__init__.py b/src/vak/common/files/__init__.py index 020debe09..1a8c0e3aa 100644 --- a/src/vak/common/files/__init__.py +++ b/src/vak/common/files/__init__.py @@ -1,2 +1,9 @@ -from .files import find_fname, from_dir from . import spect +from .files import find_fname, from_dir + + +__all__ = [ + "find_fname", + "from_dir", + "spect", +] diff --git a/src/vak/common/files/files.py b/src/vak/common/files/files.py index 4de072076..5e86cd7fa 100644 --- a/src/vak/common/files/files.py +++ b/src/vak/common/files/files.py @@ -26,9 +26,9 @@ def find_fname(fname: str, ext: str) -> str | None: >>> vak.files.find_fname(fname='llb3_0003_2018_04_23_14_18_54.wav.mat', ext='wav') 'llb3_0003_2018_04_23_14_18_54.wav' """ - if ext.startswith('.'): + if ext.startswith("."): ext = ext[1:] - m = re.match(f"[\S ]*{ext}", fname) + m = re.match(f"[\S ]*{ext}", fname) # noqa: W605 if hasattr(m, "group"): return m.group() elif m is None: @@ -71,9 +71,11 @@ def from_dir(dir_path: str | pathlib.Path, ext: str) -> list[str]: """ dir_path = pathlib.Path(dir_path) if not dir_path.is_dir(): - raise NotADirectoryError(f"dir_path not recognized as a directory: {dir_path}") + raise NotADirectoryError( + f"dir_path not recognized as a directory: {dir_path}" + ) - if ext.startswith('.'): + if ext.startswith("."): ext = ext[1:] # use fnmatch + re to make search case-insensitive @@ -84,7 +86,9 @@ def from_dir(dir_path: str | pathlib.Path, ext: str) -> list[str]: rule = re.compile(fnmatch.translate(glob_pat), re.IGNORECASE) files = [ - file for file in 
dir_path.iterdir() if file.is_file() and rule.match(file.name) + file + for file in dir_path.iterdir() + if file.is_file() and rule.match(file.name) ] if len(files) == 0: diff --git a/src/vak/common/files/spect.py b/src/vak/common/files/spect.py index bb90116a9..e7fd485b1 100644 --- a/src/vak/common/files/spect.py +++ b/src/vak/common/files/spect.py @@ -1,4 +1,5 @@ from __future__ import annotations + import logging import pathlib @@ -7,15 +8,15 @@ from dask.diagnostics import ProgressBar from .. import constants -from .files import find_fname from ..timebins import timebin_dur_from_vec - +from .files import find_fname logger = logging.getLogger(__name__) -def find_audio_fname(spect_path: str | pathlib.Path, - audio_ext: str | None = None): +def find_audio_fname( + spect_path: str | pathlib.Path, audio_ext: str | None = None +): """finds name of audio file in a path to a spectrogram file, if one is present. @@ -63,8 +64,7 @@ def find_audio_fname(spect_path: str | pathlib.Path, ) -def load(spect_path: str | pathlib.Path, - spect_format: str | None = None): +def load(spect_path: str | pathlib.Path, spect_format: str | None = None): """load spectrogram and related arrays from a file, return as an object that provides Python dictionary-like access @@ -89,14 +89,18 @@ def load(spect_path: str | pathlib.Path, if spect_format is None: # "replace('.', '')", because suffix returns file extension with period included spect_format = spect_path.suffix.replace(".", "") - spect_dict = constants.SPECT_FORMAT_LOAD_FUNCTION_MAP[spect_format](spect_path) + spect_dict = constants.SPECT_FORMAT_LOAD_FUNCTION_MAP[spect_format]( + spect_path + ) return spect_dict -def timebin_dur(spect_path: str | pathlib.Path, - spect_format: str, - timebins_key: str = 't', - n_decimals_trunc: int = 5): +def timebin_dur( + spect_path: str | pathlib.Path, + spect_format: str, + timebins_key: str = "t", + n_decimals_trunc: int = 5, +): """get duration of time bins from a spectrogram file Parameters @@ 
-170,7 +174,8 @@ def is_valid_set_of_spect_files( def _validate(spect_path): """validates each spectrogram file, then returns frequency bin array - and duration of time bins, so that those can be validated across all files""" + and duration of time bins, so that those can be validated across all files + """ spect_dict = load(spect_path, spect_format) if spect_key not in spect_dict: @@ -184,13 +189,19 @@ def _validate(spect_path): timebin_dur = timebin_dur_from_vec(time_bins, n_decimals_trunc) # number of freq. bins should equal number of rows - if spect_dict[freqbins_key].shape[-1] != spect_dict[spect_key].shape[0]: + if ( + spect_dict[freqbins_key].shape[-1] + != spect_dict[spect_key].shape[0] + ): raise ValueError( f"length of frequency bins in {spect_path.name} " "does not match number of rows in spectrogram" ) # number of time bins should equal number of columns - if spect_dict[timebins_key].shape[-1] != spect_dict[spect_key].shape[1]: + if ( + spect_dict[timebins_key].shape[-1] + != spect_dict[spect_key].shape[1] + ): raise ValueError( f"length of time_bins in {spect_path.name} " f"does not match number of columns in spectrogram" @@ -205,7 +216,9 @@ def _validate(spect_path): with ProgressBar(): path_freqbins_timebin_dur_tups = list(spect_paths_bag.map(_validate)) - all_freq_bins = np.stack([tup[1] for tup in path_freqbins_timebin_dur_tups]) + all_freq_bins = np.stack( + [tup[1] for tup in path_freqbins_timebin_dur_tups] + ) uniq_freq_bins = np.unique(all_freq_bins, axis=0) if len(uniq_freq_bins) != 1: raise ValueError( diff --git a/src/vak/common/labels.py b/src/vak/common/labels.py index df0158aff..c747e1569 100644 --- a/src/vak/common/labels.py +++ b/src/vak/common/labels.py @@ -8,8 +8,7 @@ from . 
import annotation -def to_map(labelset: set, - map_unlabeled: bool = True) -> dict: +def to_map(labelset: set, map_unlabeled: bool = True) -> dict: """Convert set of labels to `dict` mapping those labels to a series of consecutive integers from 0 to n inclusive, @@ -40,8 +39,10 @@ def to_map(labelset: set, labelmap : dict Maps labels to integers. """ - if type(labelset) != set: - raise TypeError(f"type of labelset must be set, got type {type(labelset)}") + if not isinstance(labelset, set): + raise TypeError( + f"type of labelset must be set, got type {type(labelset)}" + ) labellist = [] if map_unlabeled is True: @@ -81,7 +82,9 @@ def to_set(labels_list: list[np.ndarray | list]) -> set: return labelset -def from_df(dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path) -> list[np.ndarray]: +def from_df( + dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path +) -> list[np.ndarray]: """Returns labels for each vocalization in a dataset. Takes Pandas DataFrame representing the dataset, loads @@ -108,21 +111,21 @@ def from_df(dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path) -> list[ return [annot.seq.labels for annot in annots] -ALPHANUMERIC = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +ALPHANUMERIC = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" DUMMY_SINGLE_CHAR_LABELS = [ # some large range of characters not typically used as labels - chr(x) for x in range(162, 400) + chr(x) + for x in range(162, 400) ] # start with alphanumeric since more human readable; # mapping can be arbitrary as long as it's consistent -DUMMY_SINGLE_CHAR_LABELS = ( - *ALPHANUMERIC, - *DUMMY_SINGLE_CHAR_LABELS -) +DUMMY_SINGLE_CHAR_LABELS = (*ALPHANUMERIC, *DUMMY_SINGLE_CHAR_LABELS) # added to fix https://github.com/NickleDave/vak/issues/373 -def multi_char_labels_to_single_char(labelmap: dict, skip: tuple[str] = ('unlabeled',)) -> dict: +def multi_char_labels_to_single_char( + labelmap: dict, skip: tuple[str] = ("unlabeled",) +) -> 
dict: """Return a copy of a ``labelmap`` where any labels that are strings with multiple characters are converted to single characters. diff --git a/src/vak/common/learncurve.py b/src/vak/common/learncurve.py index e912c8936..ee291c4f6 100644 --- a/src/vak/common/learncurve.py +++ b/src/vak/common/learncurve.py @@ -1,4 +1,6 @@ -def get_train_dur_replicate_split_name(train_dur: int, replicate_num: int) -> str: +def get_train_dur_replicate_split_name( + train_dur: int, replicate_num: int +) -> str: """Get name of a training set split for a learning curve, for a specified training duration and replicate number. diff --git a/src/vak/common/logging.py b/src/vak/common/logging.py index 5449dd3d7..8ced29688 100644 --- a/src/vak/common/logging.py +++ b/src/vak/common/logging.py @@ -1,21 +1,18 @@ """utility functions for logging""" import logging -from pathlib import Path import sys import warnings +from pathlib import Path -from . import timenow from ..__about__ import __version__ +from . import timenow - -logger = logging.getLogger('vak') # 'base' logger +logger = logging.getLogger("vak") # 'base' logger -def config_logging_for_cli(log_dst: str, - log_stem: str, - level='info', - timestamp=None, - force=False): +def config_logging_for_cli( + log_dst: str, log_stem: str, level="info", timestamp=None, force=False +): """Configure logging for a run of the cli. Called by `vak.cli` functions. 
Allows logging @@ -65,7 +62,9 @@ def config_logging_for_cli(log_dst: str, ) return - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) file_handler = logging.FileHandler(logfile_name) stream_handler = logging.StreamHandler(sys.stdout) for handler in (file_handler, stream_handler): @@ -76,6 +75,4 @@ def config_logging_for_cli(log_dst: str, def log_version(logger: logging.Logger) -> None: - logger.info( - f"vak version: {__version__}" - ) + logger.info(f"vak version: {__version__}") diff --git a/src/vak/common/paths.py b/src/vak/common/paths.py index edfdc7923..212ad32fc 100644 --- a/src/vak/common/paths.py +++ b/src/vak/common/paths.py @@ -33,5 +33,7 @@ def generate_results_dir_name_as_path(root_results_dir=None): f"root_results_dir not recognized as a directory: {root_results_dir}" ) - results_dirname = f"{constants.RESULTS_DIR_PREFIX}{timenow.get_timenow_as_str()}" + results_dirname = ( + f"{constants.RESULTS_DIR_PREFIX}{timenow.get_timenow_as_str()}" + ) return root_results_dir.joinpath(results_dirname) diff --git a/src/vak/common/tensorboard.py b/src/vak/common/tensorboard.py index b9d355b0b..159853c38 100644 --- a/src/vak/common/tensorboard.py +++ b/src/vak/common/tensorboard.py @@ -2,8 +2,10 @@ from pathlib import Path import pandas as pd +from tensorboard.backend.event_processing.event_accumulator import ( + EventAccumulator, +) from torch.utils.tensorboard import SummaryWriter -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator def get_summary_writer(log_dir, filename_suffix): @@ -99,14 +101,19 @@ def events2df(events_path, size_guidance=None, drop_wall_time=True): ea = EventAccumulator(path=events_path, size_guidance=size_guidance) ea.Reload() # load all data written so far - scalar_tags = ea.Tags()["scalars"] # list of tags for values written to scalar + scalar_tags = ea.Tags()[ + "scalars" + ] # 
list of tags for values written to scalar # make a dataframe for each tag, which we will then concatenate using 'step' as the index # so that pandas will fill in with NaNs for any scalars that were not measured on every step dfs = {} for scalar_tag in scalar_tags: dfs[scalar_tag] = pd.DataFrame( - [(scalar.wall_time, scalar.step, scalar.value) for scalar in ea.Scalars(scalar_tag)], - columns=["wall_time", "step", scalar_tag] + [ + (scalar.wall_time, scalar.step, scalar.value) + for scalar in ea.Scalars(scalar_tag) + ], + columns=["wall_time", "step", scalar_tag], ).set_index("step") if drop_wall_time: dfs[scalar_tag].drop("wall_time", axis=1, inplace=True) diff --git a/src/vak/common/timebins.py b/src/vak/common/timebins.py index 6900e2794..dd1d8375a 100644 --- a/src/vak/common/timebins.py +++ b/src/vak/common/timebins.py @@ -26,8 +26,10 @@ def timebin_dur_from_vec(time_bins, n_decimals_trunc=5): to deal with floating point error, then rounds and truncates to specified decimal place """ # first we round to the given number of decimals - timebin_dur = np.around(np.mean(np.diff(time_bins)), decimals=n_decimals_trunc) + timebin_dur = np.around( + np.mean(np.diff(time_bins)), decimals=n_decimals_trunc + ) # only after rounding do we truncate any decimal place past decade - decade = 10 ** n_decimals_trunc + decade = 10**n_decimals_trunc timebin_dur = np.trunc(timebin_dur * decade) / decade return timebin_dur diff --git a/src/vak/common/timenow.py b/src/vak/common/timenow.py index 8aa2977e9..f838997f3 100644 --- a/src/vak/common/timenow.py +++ b/src/vak/common/timenow.py @@ -4,6 +4,6 @@ def get_timenow_as_str(): - f"""returns current time as a string, + """Returns current time as a string, with the format specified by ``vak.constants.STRFTIME_TIMESTAMP``""" return datetime.now().strftime(STRFTIME_TIMESTAMP) diff --git a/src/vak/common/trainer.py b/src/vak/common/trainer.py index 20ea5fb29..439d25034 100644 --- a/src/vak/common/trainer.py +++ b/src/vak/common/trainer.py 
@@ -1,51 +1,53 @@ from __future__ import annotations + import pathlib import pytorch_lightning as lightning -def get_default_train_callbacks(ckpt_root: str | pathlib.Path, - ckpt_step: int, - patience: int, ): +def get_default_train_callbacks( + ckpt_root: str | pathlib.Path, + ckpt_step: int, + patience: int, +): ckpt_callback = lightning.callbacks.ModelCheckpoint( dirpath=ckpt_root, - filename='checkpoint', + filename="checkpoint", every_n_train_steps=ckpt_step, save_last=True, verbose=True, ) - ckpt_callback.CHECKPOINT_NAME_LAST = 'checkpoint' - ckpt_callback.FILE_EXTENSION = '.pt' + ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" + ckpt_callback.FILE_EXTENSION = ".pt" val_ckpt_callback = lightning.callbacks.ModelCheckpoint( monitor="val_acc", dirpath=ckpt_root, save_top_k=1, - mode='max', - filename='max-val-acc-checkpoint', + mode="max", + filename="max-val-acc-checkpoint", auto_insert_metric_name=False, - verbose=True + verbose=True, ) - val_ckpt_callback.FILE_EXTENSION = '.pt' + val_ckpt_callback.FILE_EXTENSION = ".pt" early_stopping = lightning.callbacks.EarlyStopping( - mode='max', - monitor='val_acc', + mode="max", + monitor="val_acc", patience=patience, verbose=True, ) - return [ckpt_callback, - val_ckpt_callback, - early_stopping] + return [ckpt_callback, val_ckpt_callback, early_stopping] -def get_default_trainer(max_steps: int, - log_save_dir: str | pathlib.Path, - val_step: int, - default_callback_kwargs: dict | None = None, - device: str = 'cuda', - ) -> lightning.Trainer: +def get_default_trainer( + max_steps: int, + log_save_dir: str | pathlib.Path, + val_step: int, + default_callback_kwargs: dict | None = None, + device: str = "cuda", +) -> lightning.Trainer: """Returns an instance of ``lightning.Trainer`` with a default set of callbacks. 
Used by ``vak.core`` functions.""" @@ -54,14 +56,12 @@ def get_default_trainer(max_steps: int, else: callbacks = None - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None - logger = lightning.loggers.TensorBoardLogger( - save_dir=log_save_dir - ) + logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir) trainer = lightning.Trainer( callbacks=callbacks, diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 33172eabd..f763e238a 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -11,3 +11,17 @@ train, validators, ) + + +__all__ = [ + "config", + "eval", + "learncurve", + "model", + "parse", + "predict", + "prep", + "spect_params", + "train", + "validators", +] diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 07ad130e2..377802b3b 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -28,13 +28,16 @@ class Config: learncurve : vak.config.learncurve.LearncurveConfig represents ``[LEARNCURVE]`` section of config.toml file """ + spect_params = attr.ib( validator=instance_of(SpectParamsConfig), default=SpectParamsConfig() ) prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None) train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None) eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None) - predict = attr.ib(validator=optional(instance_of(PredictConfig)), default=None) + predict = attr.ib( + validator=optional(instance_of(PredictConfig)), default=None + ) learncurve = attr.ib( validator=optional(instance_of(LearncurveConfig)), default=None ) diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index f6bac1922..6991b89a7 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -3,26 +3,30 @@ from attr import converters, validators from attr.validators import instance_of -from .validators import is_valid_model_name from ..common import device 
from ..common.converters import expanded_user_path +from .validators import is_valid_model_name def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: post_tfm_kwargs = dict(post_tfm_kwargs) - if 'min_segment_dur' not in post_tfm_kwargs: + if "min_segment_dur" not in post_tfm_kwargs: # because there's no null in TOML, # users leave arg out of config then we set it to None - post_tfm_kwargs['min_segment_dur'] = None + post_tfm_kwargs["min_segment_dur"] = None else: - post_tfm_kwargs['min_segment_dur'] = float(post_tfm_kwargs['min_segment_dur']) + post_tfm_kwargs["min_segment_dur"] = float( + post_tfm_kwargs["min_segment_dur"] + ) - if 'majority_vote' not in post_tfm_kwargs: + if "majority_vote" not in post_tfm_kwargs: # set default for this one too - post_tfm_kwargs['majority_vote'] = False + post_tfm_kwargs["majority_vote"] = False else: - post_tfm_kwargs['majority_vote'] = bool(post_tfm_kwargs['majority_vote']) + post_tfm_kwargs["majority_vote"] = bool( + post_tfm_kwargs["majority_vote"] + ) return post_tfm_kwargs @@ -36,22 +40,27 @@ def are_valid_post_tfm_kwargs(instance, attribute, value): "Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`" ) if any( - [k not in {'majority_vote', 'min_segment_dur'} for k in value.keys()] + [k not in {"majority_vote", "min_segment_dur"} for k in value.keys()] ): - invalid_kwargs = [k for k in value.keys() - if k not in {'majority_vote', 'min_segment_dur'}] + invalid_kwargs = [ + k + for k in value.keys() + if k not in {"majority_vote", "min_segment_dur"} + ] raise ValueError( f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}." 
"Valid names are: {'majority_vote', 'min_segment_dur'}" ) - if 'majority_vote' in value: - if not isinstance(value['majority_vote'], bool): + if "majority_vote" in value: + if not isinstance(value["majority_vote"], bool): raise TypeError( "'post_tfm_kwargs' keyword argument 'majority_vote' " f"should be of type bool but was: {type(value['majority_vote'])}" ) - if 'min_segment_dur' in value: - if value['min_segment_dur'] and not isinstance(value['min_segment_dur'], float): + if "min_segment_dur" in value: + if value["min_segment_dur"] and not isinstance( + value["min_segment_dur"], float + ): raise TypeError( "'post_tfm_kwargs' keyword argument 'min_segment_dur' type " f"should be float but was: {type(value['min_segment_dur'])}" @@ -108,6 +117,7 @@ class EvalConfig: Passed as keyword arguments. Optional, default is None. """ + # required, external files checkpoint_path = attr.ib(converter=expanded_user_path) output_dir = attr.ib(converter=expanded_user_path) @@ -128,7 +138,9 @@ class EvalConfig: # "optional" but actually required for frame classification models # TODO: check model family in __post_init__ and raise ValueError if labelmap # TODO: not specified for a frame classification model? - labelmap_path = attr.ib(converter=converters.optional(expanded_user_path), default=None) + labelmap_path = attr.ib( + converter=converters.optional(expanded_user_path), default=None + ) # optional, transform spect_scaler_path = attr.ib( converter=converters.optional(expanded_user_path), diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index 0b738876d..13fc6021a 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -50,6 +50,7 @@ class LearncurveConfig(TrainConfig): See the docstring of the transform for more details on these arguments and how they work. 
""" + post_tfm_kwargs = attr.ib( validator=validators.optional(are_valid_post_tfm_kwargs), converter=converters.optional(convert_post_tfm_kwargs), diff --git a/src/vak/config/model.py b/src/vak/config/model.py index 3426cd50f..d1d36643b 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -1,17 +1,17 @@ from __future__ import annotations + import pathlib import toml from .. import models - MODEL_TABLES = [ - "network", - "optimizer", - "loss", - "metrics", - ] + "network", + "optimizer", + "loss", + "metrics", +] def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: @@ -57,7 +57,9 @@ def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: return model_config -def config_from_toml_path(toml_path: str | pathlib.Path, model_name: str) -> dict: +def config_from_toml_path( + toml_path: str | pathlib.Path, model_name: str +) -> dict: """Get configuration for a model from a .toml configuration file, given the path to the file. diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index c69183a73..295a1a537 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -10,8 +10,7 @@ from .prep import PrepConfig from .spect_params import SpectParamsConfig from .train import TrainConfig -from .validators import are_sections_valid, are_options_valid - +from .validators import are_options_valid, are_sections_valid SECTION_CLASSES = { "EVAL": EvalConfig, @@ -92,10 +91,17 @@ def _validate_sections_arg_convert_list(sections): if isinstance(sections, str): sections = [sections] elif isinstance(sections, list): - if not all([isinstance(section_name, str) for section_name in sections]): - raise ValueError("all section names in 'sections' should be strings") if not all( - [section_name in list(SECTION_CLASSES.keys()) for section_name in sections] + [isinstance(section_name, str) for section_name in sections] + ): + raise ValueError( + "all section names in 'sections' should be strings" + ) + if not all( + [ + section_name 
in list(SECTION_CLASSES.keys()) + for section_name in sections + ] ): raise ValueError( "all section names in 'sections' should be valid names of sections. " @@ -162,7 +168,9 @@ def _load_toml_from_path(toml_path): with toml_path.open("r") as fp: config_toml = toml.load(fp) except TomlDecodeError as e: - raise Exception(f"Error when parsing .toml config file: {toml_path}") from e + raise Exception( + f"Error when parsing .toml config file: {toml_path}" + ) from e return config_toml diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index fb221c82e..852605165 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -6,9 +6,9 @@ from attr import converters, validators from attr.validators import instance_of -from .validators import is_valid_model_name from ..common import device from ..common.converters import expanded_user_path +from .validators import is_valid_model_name @attr.s diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index 19bfc7d1a..717b7b445 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -2,17 +2,13 @@ import inspect import attr +import dask.bag from attr import converters, validators from attr.validators import instance_of -import dask.bag -from .validators import ( - is_audio_format, - is_annot_format, - is_spect_format, -) -from ..common.converters import expanded_user_path, labelset_to_set from .. 
import prep +from ..common.converters import expanded_user_path, labelset_to_set +from .validators import is_annot_format, is_audio_format, is_spect_format def duration_from_toml_value(value): @@ -56,9 +52,11 @@ def are_valid_dask_bag_kwargs(instance, attribute, value): inspect.signature(dask.bag.from_sequence).parameters.keys() ) if not all([kwarg in valid_bag_kwargs for kwarg in kwargs]): - invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_bag_kwargs] + invalid_kwargs = [ + kwarg for kwarg in kwargs if kwarg not in valid_bag_kwargs + ] print( - f'Invalid keyword arguments specified in ``audio_dask_bag_kwargs``: {invalid_kwargs}' + f"Invalid keyword arguments specified in ``audio_dask_bag_kwargs``: {invalid_kwargs}" ) @@ -120,10 +118,12 @@ class PrepConfig: randomly drawn subset of the training data (but of the same duration). Default is None. Required if config file has a learncurve section. """ + data_dir = attr.ib(converter=expanded_user_path) output_dir = attr.ib(converter=expanded_user_path) dataset_type = attr.ib(validator=instance_of(str)) + @dataset_type.validator def is_valid_dataset_type(self, attribute, value): if value not in prep.constants.DATASET_TYPES: @@ -133,6 +133,7 @@ def is_valid_dataset_type(self, attribute, value): ) input_type = attr.ib(validator=instance_of(str)) + @input_type.validator def is_valid_input_type(self, attribute, value): if value not in prep.constants.INPUT_TYPES: @@ -140,13 +141,19 @@ def is_valid_input_type(self, attribute, value): f"Invalid input type: {value}. 
Must be one of: {prep.constants.INPUT_TYPES}" ) - audio_format = attr.ib(validator=validators.optional(is_audio_format), default=None) - spect_format = attr.ib(validator=validators.optional(is_spect_format), default=None) + audio_format = attr.ib( + validator=validators.optional(is_audio_format), default=None + ) + spect_format = attr.ib( + validator=validators.optional(is_spect_format), default=None + ) annot_file = attr.ib( converter=converters.optional(expanded_user_path), default=None, ) - annot_format = attr.ib(validator=validators.optional(is_annot_format), default=None) + annot_format = attr.ib( + validator=validators.optional(is_annot_format), default=None + ) labelset = attr.ib( converter=converters.optional(labelset_to_set), @@ -154,7 +161,9 @@ def is_valid_input_type(self, attribute, value): default=None, ) - audio_dask_bag_kwargs = attr.ib(validator=validators.optional(are_valid_dask_bag_kwargs), default=None) + audio_dask_bag_kwargs = attr.ib( + validator=validators.optional(are_valid_dask_bag_kwargs), default=None + ) train_dur = attr.ib( converter=converters.optional(duration_from_toml_value), @@ -171,12 +180,18 @@ def is_valid_input_type(self, attribute, value): validator=validators.optional(is_valid_duration), default=None, ) - train_set_durs = attr.ib(validator=validators.optional(instance_of(list)), default=None) - num_replicates = attr.ib(validator=validators.optional(instance_of(int)), default=None) + train_set_durs = attr.ib( + validator=validators.optional(instance_of(list)), default=None + ) + num_replicates = attr.ib( + validator=validators.optional(instance_of(int)), default=None + ) def __attrs_post_init__(self): if self.audio_format is not None and self.spect_format is not None: - raise ValueError(f"cannot specify audio_format and spect_format") + raise ValueError("cannot specify audio_format and spect_format") if self.audio_format is None and self.spect_format is None: - raise ValueError(f"must specify either audio_format or 
spect_format") + raise ValueError( + "must specify either audio_format or spect_format" + ) diff --git a/src/vak/config/spect_params.py b/src/vak/config/spect_params.py index 363550cd2..4a61942a6 100644 --- a/src/vak/config/spect_params.py +++ b/src/vak/config/spect_params.py @@ -70,7 +70,9 @@ class SpectParamsConfig: default=None, ) transform_type = attr.ib( - validator=validators.optional([instance_of(str), is_valid_transform_type]), + validator=validators.optional( + [instance_of(str), is_valid_transform_type] + ), default=None, ) spect_key = attr.ib(validator=instance_of(str), default="s") diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 6bb9d4468..034a110b4 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -3,9 +3,9 @@ from attr import converters, validators from attr.validators import instance_of -from .validators import is_valid_model_name from ..common import device from ..common.converters import bool_from_str, expanded_user_path +from .validators import is_valid_model_name @attr.s @@ -55,13 +55,14 @@ class TrainConfig: validation set improving before stopping the training. Default is None, in which case training only stops after the specified number of epochs. checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model. - Default is None, in which case a new model is initialized. + path to directory with checkpoint files saved by Torch, to reload model. + Default is None, in which case a new model is initialized. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. Default is None. 
""" + # required model = attr.ib( validator=[instance_of(str), is_valid_model_name], diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 60778ca77..71d757d10 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -75,7 +75,9 @@ def are_sections_valid(config_dict, toml_path=None): command for command in CLI_COMMANDS if command != "prep" ] sections_that_are_commands_besides_prep = [ - section for section in sections if section.lower() in cli_commands_besides_prep + section + for section in sections + if section.lower() in cli_commands_besides_prep ] if len(sections_that_are_commands_besides_prep) == 0: raise ValueError( @@ -94,11 +96,18 @@ def are_sections_valid(config_dict, toml_path=None): # add model names to valid sections so users can define model config in sections valid_sections = VALID_SECTIONS + MODEL_NAMES for section in sections: - if section not in valid_sections and f"{section}Model" not in valid_sections: + if ( + section not in valid_sections + and f"{section}Model" not in valid_sections + ): if toml_path: - err_msg = f"section defined in {toml_path} is not valid: {section}" + err_msg = ( + f"section defined in {toml_path} is not valid: {section}" + ) else: - err_msg = f"section defined in toml config is not valid: {section}" + err_msg = ( + f"section defined in toml config is not valid: {section}" + ) raise ValueError(err_msg) diff --git a/src/vak/datasets/__init__.py b/src/vak/datasets/__init__.py index 55dc89254..fdbbe9cbc 100644 --- a/src/vak/datasets/__init__.py +++ b/src/vak/datasets/__init__.py @@ -1,10 +1,7 @@ -from . import ( - frame_classification, - parametric_umap -) +from . 
import frame_classification, parametric_umap __all__ = [ - "dimensionality_reduction", "frame_classification", + "parametric_umap" ] diff --git a/src/vak/datasets/frame_classification/__init__.py b/src/vak/datasets/frame_classification/__init__.py index b9be94832..98eda614b 100644 --- a/src/vak/datasets/frame_classification/__init__.py +++ b/src/vak/datasets/frame_classification/__init__.py @@ -3,10 +3,4 @@ from .metadata import Metadata from .window_dataset import WindowDataset - -__all__ = [ - "constants", - "Metadata", - "FramesDataset", - "WindowDataset" -] +__all__ = ["constants", "Metadata", "FramesDataset", "WindowDataset"] diff --git a/src/vak/datasets/frame_classification/constants.py b/src/vak/datasets/frame_classification/constants.py index 89e6f0dce..0ec942562 100644 --- a/src/vak/datasets/frame_classification/constants.py +++ b/src/vak/datasets/frame_classification/constants.py @@ -1,8 +1,8 @@ -FRAMES_ARRAY_EXT = '.frames.npy' -FRAMES_NPY_PATH_COL_NAME = 'frames_npy_path' -FRAME_LABELS_EXT = '.frame_labels.npy' -FRAME_LABELS_NPY_PATH_COL_NAME = 'frame_labels_npy_path' -ANNOTATION_CSV_FILENAME = 'y.csv' -SAMPLE_IDS_ARRAY_FILENAME = 'sample_ids.npy' -INDS_IN_SAMPLE_ARRAY_FILENAME = 'inds_in_sample.npy' -WINDOW_INDS_ARRAY_FILENAME = 'window_inds.npy' +FRAMES_ARRAY_EXT = ".frames.npy" +FRAMES_NPY_PATH_COL_NAME = "frames_npy_path" +FRAME_LABELS_EXT = ".frame_labels.npy" +FRAME_LABELS_NPY_PATH_COL_NAME = "frame_labels_npy_path" +ANNOTATION_CSV_FILENAME = "y.csv" +SAMPLE_IDS_ARRAY_FILENAME = "sample_ids.npy" +INDS_IN_SAMPLE_ARRAY_FILENAME = "inds_in_sample.npy" +WINDOW_INDS_ARRAY_FILENAME = "window_inds.npy" diff --git a/src/vak/datasets/frame_classification/frames_dataset.py b/src/vak/datasets/frame_classification/frames_dataset.py index b66a640c2..6713ba4ca 100644 --- a/src/vak/datasets/frame_classification/frames_dataset.py +++ b/src/vak/datasets/frame_classification/frames_dataset.py @@ -27,6 +27,7 @@ class FramesDataset: duration : float Total duration 
of the dataset. """ + def __init__( self, dataset_path: str | pathlib.Path, @@ -43,16 +44,20 @@ def __init__( self.split = split dataset_df = dataset_df[dataset_df.split == split].copy() self.dataset_df = dataset_df - self.frames_paths = self.dataset_df[constants.FRAMES_NPY_PATH_COL_NAME].values - if split != 'predict': - self.frame_labels_paths = self.dataset_df[constants.FRAME_LABELS_NPY_PATH_COL_NAME].values + self.frames_paths = self.dataset_df[ + constants.FRAMES_NPY_PATH_COL_NAME + ].values + if split != "predict": + self.frame_labels_paths = self.dataset_df[ + constants.FRAME_LABELS_NPY_PATH_COL_NAME + ].values else: self.frame_labels_paths = None - if input_type == 'audio': - self.source_paths = self.dataset_df['audio_path'].values - elif input_type == 'spect': - self.source_paths = self.dataset_df['spect_path'].values + if input_type == "audio": + self.source_paths = self.dataset_df["audio_path"].values + elif input_type == "spect": + self.source_paths = self.dataset_df["spect_path"].values else: raise ValueError( f"Invalid `input_type`: {input_type}. Must be one of {{'audio', 'spect'}}." 
@@ -76,13 +81,12 @@ def shape(self): def __getitem__(self, idx): source_path = self.source_paths[idx] frames = np.load(self.dataset_path / self.frames_paths[idx]) - item = { - 'frames': frames, - 'source_path': source_path - } + item = {"frames": frames, "source_path": source_path} if self.frame_labels_paths is not None: - frame_labels = np.load(self.dataset_path / self.frame_labels_paths[idx]) - item['frame_labels'] = frame_labels + frame_labels = np.load( + self.dataset_path / self.frame_labels_paths[idx] + ) + item["frame_labels"] = frame_labels if self.item_transform: item = self.item_transform(**item) @@ -123,7 +127,9 @@ def from_dataset_path( split_path = dataset_path / split sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME sample_ids = np.load(sample_ids_path) - inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME + inds_in_sample_path = ( + split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME + ) inds_in_sample = np.load(inds_in_sample_path) return cls( diff --git a/src/vak/datasets/frame_classification/metadata.py b/src/vak/datasets/frame_classification/metadata.py index 20542c31b..074e5d26b 100644 --- a/src/vak/datasets/frame_classification/metadata.py +++ b/src/vak/datasets/frame_classification/metadata.py @@ -12,19 +12,20 @@ def is_valid_dataset_csv_filename(instance, attribute, value): - valid = '_prep_' in value and value.endswith('.csv') + valid = "_prep_" in value and value.endswith(".csv") if not valid: raise ValueError( - f'Invalid dataset csv filename: {value}.' + f"Invalid dataset csv filename: {value}." f'Filename should contain the string "_prep_" ' - f'and end with the extension .csv.' - f'Valid filenames are generated by ' - f'vak.core.prep.generate_dataset_csv_filename' + f"and end with the extension .csv." 
+ f"Valid filenames are generated by " + f"vak.core.prep.generate_dataset_csv_filename" ) def is_valid_audio_format(instance, attribute, value): import vak.common.constants + if value not in vak.common.constants.VALID_AUDIO_FORMATS: raise ValueError( f"Not a valid audio format: {value}. Valid audio formats are: {vak.common.constants.VALID_AUDIO_FORMATS}" @@ -33,6 +34,7 @@ def is_valid_audio_format(instance, attribute, value): def is_valid_spect_format(instance, attribute, value): import vak.common.constants + if value not in vak.common.constants.VALID_SPECT_FORMATS: raise ValueError( f"Not a valid spectrogram format: {value}. " @@ -59,13 +61,17 @@ class Metadata: The modality of the input data "frames", either audio signals or spectrograms. One of {'audio', 'spect'}. """ + # declare this as a constant to avoid # needing to remember this in multiple places, and to use in unit tests - METADATA_JSON_FILENAME: ClassVar = 'metadata.json' + METADATA_JSON_FILENAME: ClassVar = "metadata.json" - dataset_csv_filename: str = attr.field(converter=str, validator=is_valid_dataset_csv_filename) + dataset_csv_filename: str = attr.field( + converter=str, validator=is_valid_dataset_csv_filename + ) input_type: str = attr.field() + @input_type.validator def is_valid_input_type(self, attribute, value): if not isinstance(value, str): @@ -73,34 +79,32 @@ def is_valid_input_type(self, attribute, value): f"{attribute.name} value should be a string but was type {type(value)}" ) from ...prep.constants import INPUT_TYPES - if not value in INPUT_TYPES: + + if value not in INPUT_TYPES: raise ValueError( f"Value for {attribute.name} is not a valid input type: '{value}'\n" f"Valid input types are: {INPUT_TYPES}" ) frame_dur: float = attr.field(converter=float) + @frame_dur.validator def is_valid_frame_dur(self, attribute, value): if not isinstance(value, float): - raise ValueError( - f"{attribute.name} should be a float value." 
- ) - if not value > 0.: - raise ValueError( - f"{attribute.name} should be greater than zero." - ) + raise ValueError(f"{attribute.name} should be a float value.") + if not value > 0.0: + raise ValueError(f"{attribute.name} should be greater than zero.") audio_format: str = attr.field( converter=attr.converters.optional(str), validator=attr.validators.optional(is_valid_audio_format), - default=None + default=None, ) spect_format: str = attr.field( converter=attr.converters.optional(str), validator=attr.validators.optional(is_valid_spect_format), - default=None + default=None, ) @classmethod @@ -124,7 +128,7 @@ def from_path(cls, json_path: str | pathlib.Path): with metadata loaded from json file. """ json_path = pathlib.Path(json_path) - with json_path.open('r') as fp: + with json_path.open("r") as fp: metadata_json = json.load(fp) return cls(**metadata_json) @@ -134,7 +138,7 @@ def from_dataset_path(cls, dataset_path: str | pathlib.Path): if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" - ) + ) metadata_json_path = dataset_path / cls.METADATA_JSON_FILENAME if not metadata_json_path.exists(): @@ -163,10 +167,10 @@ def to_json(self, dataset_path: str | pathlib.Path) -> None: dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( - f'dataset_path not recognized as a directory: {dataset_path}' + f"dataset_path not recognized as a directory: {dataset_path}" ) json_dict = attr.asdict(self) json_path = dataset_path / self.METADATA_JSON_FILENAME - with json_path.open('w') as fp: + with json_path.open("w") as fp: json.dump(json_dict, fp, indent=4) diff --git a/src/vak/datasets/frame_classification/window_dataset.py b/src/vak/datasets/frame_classification/window_dataset.py index d6397d44b..aa61d37ff 100644 --- a/src/vak/datasets/frame_classification/window_dataset.py +++ 
b/src/vak/datasets/frame_classification/window_dataset.py @@ -95,18 +95,18 @@ class WindowDataset: """ def __init__( - self, - dataset_path: str | pathlib.Path, - dataset_df: pd.DataFrame, - split: str, - sample_ids: npt.NDArray, - inds_in_sample: npt.NDArray, - window_size: int, - frame_dur: float, - stride: int = 1, - window_inds: npt.NDArray | None = None, - transform: Callable | None = None, - target_transform: Callable | None = None + self, + dataset_path: str | pathlib.Path, + dataset_df: pd.DataFrame, + split: str, + sample_ids: npt.NDArray, + inds_in_sample: npt.NDArray, + window_size: int, + frame_dur: float, + stride: int = 1, + window_inds: npt.NDArray | None = None, + transform: Callable | None = None, + target_transform: Callable | None = None, ): self.dataset_path = pathlib.Path(dataset_path) @@ -114,8 +114,12 @@ def __init__( dataset_df = dataset_df[dataset_df.split == split].copy() self.dataset_df = dataset_df - self.frames_paths = self.dataset_df[constants.FRAMES_NPY_PATH_COL_NAME].values - self.frame_labels_paths = self.dataset_df[constants.FRAME_LABELS_NPY_PATH_COL_NAME].values + self.frames_paths = self.dataset_df[ + constants.FRAMES_NPY_PATH_COL_NAME + ].values + self.frame_labels_paths = self.dataset_df[ + constants.FRAME_LABELS_NPY_PATH_COL_NAME + ].values self.sample_ids = sample_ids self.inds_in_sample = inds_in_sample @@ -125,7 +129,9 @@ def __init__( self.stride = stride if window_inds is None: - window_inds = get_window_inds(sample_ids.shape[-1], window_size, stride) + window_inds = get_window_inds( + sample_ids.shape[-1], window_size, stride + ) self.window_inds = window_inds self.transform = transform @@ -145,12 +151,16 @@ def shape(self): def __getitem__(self, idx): window_idx = self.window_inds[idx] - sample_ids = self.sample_ids[window_idx:window_idx + self.window_size] + sample_ids = self.sample_ids[ + window_idx: window_idx + self.window_size + ] uniq_sample_ids = np.unique(sample_ids) if len(uniq_sample_ids) == 1: sample_id = 
uniq_sample_ids[0] frames = np.load(self.dataset_path / self.frames_paths[sample_id]) - frame_labels = np.load(self.dataset_path / self.frame_labels_paths[sample_id]) + frame_labels = np.load( + self.dataset_path / self.frame_labels_paths[sample_id] + ) elif len(uniq_sample_ids) > 1: frames = [] frame_labels = [] @@ -159,7 +169,9 @@ def __getitem__(self, idx): np.load(self.dataset_path / self.frames_paths[sample_id]) ) frame_labels.append( - np.load(self.dataset_path / self.frame_labels_paths[sample_id]) + np.load( + self.dataset_path / self.frame_labels_paths[sample_id] + ) ) if all([frames_.ndim == 1 for frames_ in frames]): @@ -174,8 +186,12 @@ def __getitem__(self, idx): ) inds_in_sample = self.inds_in_sample[window_idx] - frames = frames[..., inds_in_sample:inds_in_sample + self.window_size] - frame_labels = frame_labels[inds_in_sample:inds_in_sample + self.window_size] + frames = frames[ + ..., inds_in_sample: inds_in_sample + self.window_size + ] + frame_labels = frame_labels[ + inds_in_sample: inds_in_sample + self.window_size + ] if self.transform: frames = self.transform(frames) if self.target_transform: @@ -189,13 +205,13 @@ def __len__(self): @classmethod def from_dataset_path( - cls, - dataset_path: str | pathlib.Path, - window_size: int, - stride: int = 1, - split: str = "train", - transform: Callable | None = None, - target_transform: Callable | None = None + cls, + dataset_path: str | pathlib.Path, + window_size: int, + stride: int = 1, + split: str = "train", + transform: Callable | None = None, + target_transform: Callable | None = None, ): """ @@ -222,7 +238,9 @@ def from_dataset_path( split_path = dataset_path / split sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME sample_ids = np.load(sample_ids_path) - inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME + inds_in_sample_path = ( + split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME + ) inds_in_sample = np.load(inds_in_sample_path) window_inds_path = 
split_path / constants.WINDOW_INDS_ARRAY_FILENAME @@ -242,5 +260,5 @@ def from_dataset_path( stride, window_inds, transform, - target_transform + target_transform, ) diff --git a/src/vak/datasets/parametric_umap/__init__.py b/src/vak/datasets/parametric_umap/__init__.py index 90caf061f..42ef1f109 100644 --- a/src/vak/datasets/parametric_umap/__init__.py +++ b/src/vak/datasets/parametric_umap/__init__.py @@ -1,8 +1,4 @@ from .metadata import Metadata from .parametric_umap import ParametricUMAPDataset - -__all__ = [ - 'Metadata', - 'ParametricUMAPDataset' -] +__all__ = ["Metadata", "ParametricUMAPDataset"] diff --git a/src/vak/datasets/parametric_umap/metadata.py b/src/vak/datasets/parametric_umap/metadata.py index 339cab793..ac0b8a137 100644 --- a/src/vak/datasets/parametric_umap/metadata.py +++ b/src/vak/datasets/parametric_umap/metadata.py @@ -12,19 +12,20 @@ def is_valid_dataset_csv_filename(instance, attribute, value): - valid = '_prep_' in value and value.endswith('.csv') + valid = "_prep_" in value and value.endswith(".csv") if not valid: raise ValueError( - f'Invalid dataset csv filename: {value}.' + f"Invalid dataset csv filename: {value}." f'Filename should contain the string "_prep_" ' - f'and end with the extension .csv.' - f'Valid filenames are generated by ' - f'vak.core.prep.generate_dataset_csv_filename' + f"and end with the extension .csv." + f"Valid filenames are generated by " + f"vak.core.prep.generate_dataset_csv_filename" ) def is_valid_audio_format(instance, attribute, value): import vak.common.constants + if value not in vak.common.constants.VALID_AUDIO_FORMATS: raise ValueError( f"Not a valid audio format: {value}. 
Valid audio formats are: {vak.common.constants.VALID_AUDIO_FORMATS}" @@ -33,6 +34,7 @@ def is_valid_audio_format(instance, attribute, value): def is_valid_spect_format(instance, attribute, value): import vak.common.constants + if value not in vak.common.constants.VALID_SPECT_FORMATS: raise ValueError( f"Not a valid spectrogram format: {value}. " @@ -54,22 +56,24 @@ class Metadata: so only the filename is given. audio_format """ + # declare this as a constant to avoid # needing to remember this in multiple places, and to use in unit tests - METADATA_JSON_FILENAME: ClassVar = 'metadata.json' + METADATA_JSON_FILENAME: ClassVar = "metadata.json" - dataset_csv_filename: str = attr.field(converter=str, validator=is_valid_dataset_csv_filename) + dataset_csv_filename: str = attr.field( + converter=str, validator=is_valid_dataset_csv_filename + ) shape: tuple = attr.field(converter=tuple) + @shape.validator def is_valid_shape(self, attribute, value): if not isinstance(value, tuple): raise TypeError( f"`shape` should be a tuple but type was: {type(value)}" ) - if not all( - [isinstance(val, int) and val > 0 for val in value] - ): + if not all([isinstance(val, int) and val > 0 for val in value]): raise ValueError( f"All values of `shape` should be positive integers but values were: {value}" ) @@ -77,7 +81,7 @@ def is_valid_shape(self, attribute, value): audio_format: str = attr.field( converter=attr.converters.optional(str), validator=attr.validators.optional(is_valid_audio_format), - default=None + default=None, ) @classmethod @@ -101,7 +105,7 @@ def from_path(cls, json_path: str | pathlib.Path): with metadata loaded from json file. 
""" json_path = pathlib.Path(json_path) - with json_path.open('r') as fp: + with json_path.open("r") as fp: metadata_json = json.load(fp) return cls(**metadata_json) @@ -111,7 +115,7 @@ def from_dataset_path(cls, dataset_path: str | pathlib.Path): if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" - ) + ) metadata_json_path = dataset_path / cls.METADATA_JSON_FILENAME if not metadata_json_path.exists(): @@ -140,10 +144,10 @@ def to_json(self, dataset_path: str | pathlib.Path) -> None: dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( - f'dataset_path not recognized as a directory: {dataset_path}' + f"dataset_path not recognized as a directory: {dataset_path}" ) json_dict = attr.asdict(self) json_path = dataset_path / self.METADATA_JSON_FILENAME - with json_path.open('w') as fp: + with json_path.open("w") as fp: json.dump(json_dict, fp, indent=4) diff --git a/src/vak/datasets/parametric_umap/parametric_umap.py b/src/vak/datasets/parametric_umap/parametric_umap.py index 9dcbaabca..2281b4d21 100644 --- a/src/vak/datasets/parametric_umap/parametric_umap.py +++ b/src/vak/datasets/parametric_umap/parametric_umap.py @@ -12,18 +12,26 @@ from sklearn.utils import check_random_state from torch.utils.data import Dataset +# isort: off # Ignore warnings from Numba deprecation: # https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit # Numba is required by UMAP. 
from numba.core.errors import NumbaDeprecationWarning -warnings.simplefilter('ignore', category=NumbaDeprecationWarning) -from umap.umap_ import fuzzy_simplicial_set +warnings.simplefilter("ignore", category=NumbaDeprecationWarning) +from umap.umap_ import fuzzy_simplicial_set # noqa: E402 +# isort: on -def get_umap_graph(X: npt.NDArray, n_neighbors: int = 10, metric: str= "euclidean", - random_state: np.random.RandomState | None = None, - max_candidates: int = 60, verbose: bool = True) -> scipy.sparse._coo.coo_matrix: - """Get graph used by UMAP, + +def get_umap_graph( + X: npt.NDArray, + n_neighbors: int = 10, + metric: str = "euclidean", + random_state: np.random.RandomState | None = None, + max_candidates: int = 60, + verbose: bool = True, +) -> scipy.sparse._coo.coo_matrix: + r"""Get graph used by UMAP, the fuzzy topological representation. Parameters @@ -76,7 +84,9 @@ def get_umap_graph(X: npt.NDArray, n_neighbors: int = 10, metric: str= "euclidea (where :math:`k` is a hyperparameter). In the UMAP package, these are calculated using :func:`umap._umap.smooth_knn_dist`. 
""" - random_state = check_random_state(None) if random_state == None else random_state + random_state = ( + check_random_state(None) if random_state is None else random_state + ) # number of trees in random projection forest n_trees = 5 + int(round((X.shape[0]) ** 0.5 / 20.0)) @@ -92,7 +102,7 @@ def get_umap_graph(X: npt.NDArray, n_neighbors: int = 10, metric: str= "euclidea n_trees=n_trees, n_iters=n_iters, max_candidates=max_candidates, - verbose=verbose + verbose=verbose, ) # get indices and distances for 10 nearest neighbors of every point in dataset @@ -112,8 +122,15 @@ def get_umap_graph(X: npt.NDArray, n_neighbors: int = 10, metric: str= "euclidea def get_graph_elements( - graph: scipy.sparse._coo.coo_matrix, n_epochs: int -) -> tuple[scipy.sparse._coo.coo_matrix, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, int]: + graph: scipy.sparse._coo.coo_matrix, n_epochs: int +) -> tuple[ + scipy.sparse._coo.coo_matrix, + npt.NDArray, + npt.NDArray, + npt.NDArray, + npt.NDArray, + int, +]: """Get graph elements for Parametric UMAP Dataset. 
Parameters @@ -168,12 +185,24 @@ def get_graph_elements( class ParametricUMAPDataset(Dataset): - """Dataset used for training Parametric UMAP models - - """ - def __init__(self, data: npt.NDArray, graph, - dataset_df: pd.DataFrame, n_epochs: int = 200, transform: Callable | None = None): - graph, epochs_per_sample, head, tail, weight, n_vertices = get_graph_elements(graph, n_epochs) + """Dataset used for training Parametric UMAP models""" + + def __init__( + self, + data: npt.NDArray, + graph, + dataset_df: pd.DataFrame, + n_epochs: int = 200, + transform: Callable | None = None, + ): + ( + graph, + epochs_per_sample, + head, + tail, + weight, + n_vertices, + ) = get_graph_elements(graph, n_epochs) # we repeat each sample in (head, tail) a certain number of times depending on its probability self.edges_to_exp, self.edges_from_exp = ( @@ -183,7 +212,9 @@ def __init__(self, data: npt.NDArray, graph, # we then shuffle -- not sure this is necessary if the dataset is shuffled during training? 
shuffle_mask = np.random.permutation(np.arange(len(self.edges_to_exp))) self.edges_to_exp = self.edges_to_exp[shuffle_mask].astype(np.int64) - self.edges_from_exp = self.edges_from_exp[shuffle_mask].astype(np.int64) + self.edges_from_exp = self.edges_from_exp[shuffle_mask].astype( + np.int64 + ) self.data = data self.dataset_df = dataset_df @@ -191,7 +222,7 @@ def __init__(self, data: npt.NDArray, graph, @property def duration(self): - return self.dataset_df['duration'].sum() + return self.dataset_df["duration"].sum() def __len__(self): return self.edges_to_exp.shape[0] @@ -211,14 +242,16 @@ def __getitem__(self, index): return (edges_to_exp, edges_from_exp) @classmethod - def from_dataset_path(cls, - dataset_path: str | pathlib.Path, - split: str, - n_neighbors: int = 10, - metric: str = 'euclidean', - random_state: int | None = None, - n_epochs:int = 200, - transform: Callable | None = None): + def from_dataset_path( + cls, + dataset_path: str | pathlib.Path, + split: str, + n_neighbors: int = 10, + metric: str = "euclidean", + random_state: int | None = None, + n_epochs: int = 200, + transform: Callable | None = None, + ): """ Parameters @@ -239,7 +272,9 @@ def from_dataset_path(cls, import vak.datasets # import here just to make classmethod more explicit dataset_path = pathlib.Path(dataset_path) - metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) + metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) @@ -247,10 +282,16 @@ def from_dataset_path(cls, data = np.stack( [ - np.load(dataset_path / spect_path) for spect_path in split_df.spect_path.values + np.load(dataset_path / spect_path) + for spect_path in split_df.spect_path.values ] ) - graph = get_umap_graph(data, n_neighbors=n_neighbors, metric=metric, random_state=random_state) + graph = get_umap_graph( + data, + n_neighbors=n_neighbors, + 
metric=metric, + random_state=random_state, + ) return cls( data, @@ -262,14 +303,19 @@ def from_dataset_path(cls, class ParametricUMAPInferenceDataset(Dataset): - def __init__(self, data: npt.NDArray, dataset_df: pd.DataFrame, transform: Callable | None = None): + def __init__( + self, + data: npt.NDArray, + dataset_df: pd.DataFrame, + transform: Callable | None = None, + ): self.data = data self.dataset_df = dataset_df self.transform = transform @property def duration(self): - return self.dataset_df['duration'].sum() + return self.dataset_df["duration"].sum() def __len__(self): return self.data.shape[0] @@ -284,18 +330,20 @@ def __getitem__(self, index): x = self.data[index] df_index = self.dataset_df.index[index] if self.transform: - x= self.transform(x) - return {'x': x, 'df_index': df_index} + x = self.transform(x) + return {"x": x, "df_index": df_index} @classmethod - def from_dataset_path(cls, - dataset_path: str | pathlib.Path, - split: str, - n_neighbors: int = 10, - metric: str = 'euclidean', - random_state: int | None = None, - n_epochs:int = 200, - transform: Callable | None = None): + def from_dataset_path( + cls, + dataset_path: str | pathlib.Path, + split: str, + n_neighbors: int = 10, + metric: str = "euclidean", + random_state: int | None = None, + n_epochs: int = 200, + transform: Callable | None = None, + ): """ Parameters @@ -316,7 +364,9 @@ def from_dataset_path(cls, import vak.datasets # import here just to make classmethod more explicit dataset_path = pathlib.Path(dataset_path) - metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) + metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) @@ -324,7 +374,8 @@ def from_dataset_path(cls, data = np.stack( [ - np.load(dataset_path / spect_path) for spect_path in split_df.spect_path.values + np.load(dataset_path / spect_path) + for 
spect_path in split_df.spect_path.values ] ) return cls( diff --git a/src/vak/eval/__init__.py b/src/vak/eval/__init__.py index 355c39711..67a5fe219 100644 --- a/src/vak/eval/__init__.py +++ b/src/vak/eval/__init__.py @@ -1 +1,9 @@ -from .eval import * +from . import frame_classification, parametric_umap +from .eval import eval + + +__all__ = [ + "eval", + "frame_classification", + "parametric_umap", +] diff --git a/src/vak/eval/eval.py b/src/vak/eval/eval.py index b99be8817..7f57d8f99 100644 --- a/src/vak/eval/eval.py +++ b/src/vak/eval/eval.py @@ -4,13 +4,10 @@ import logging import pathlib +from .. import models +from ..common import validators from .frame_classification import eval_frame_classification_model from .parametric_umap import eval_parametric_umap_model -from .. import ( - models, -) -from ..common import validators - logger = logging.getLogger(__name__) @@ -29,7 +26,7 @@ def eval( split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, - device: str | None = None + device: str | None = None, ) -> None: """Evaluate a trained model. 
@@ -99,8 +96,8 @@ def eval( """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( - (checkpoint_path, labelmap_path, spect_scaler_path), - ('checkpoint_path', 'labelmap_path', 'spect_scaler_path'), + (checkpoint_path, labelmap_path, spect_scaler_path), + ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: # because `spect_scaler_path` is optional if not validators.is_a_file(path): @@ -151,6 +148,4 @@ def eval( device=device, ) else: - raise ValueError( - f"Model family not recognized: {model_family}" - ) + raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 3b017adec..52188cd7c 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -1,26 +1,21 @@ """Function that evaluates trained models in the frame classification family.""" from __future__ import annotations -from collections import OrderedDict -from datetime import datetime import json import logging import pathlib +from collections import OrderedDict +from datetime import datetime import joblib -import pytorch_lightning as lightning import pandas as pd +import pytorch_lightning as lightning import torch.utils.data -from .. import ( - datasets, - models, - transforms, -) +from .. 
import datasets, models, transforms from ..common import validators from ..datasets.frame_classification import FramesDataset - logger = logging.getLogger(__name__) @@ -103,8 +98,8 @@ def eval_frame_classification_model( """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( - (checkpoint_path, labelmap_path, spect_scaler_path), - ('checkpoint_path', 'labelmap_path', 'spect_scaler_path'), + (checkpoint_path, labelmap_path, spect_scaler_path), + ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: # because `spect_scaler_path` is optional if not validators.is_a_file(path): @@ -119,7 +114,9 @@ def eval_frame_classification_model( ) # we unpack `frame_dur` to log it, regardless of whether we use it with post_tfm below - metadata = datasets.frame_classification.Metadata.from_dataset_path(dataset_path) + metadata = datasets.frame_classification.Metadata.from_dataset_path( + dataset_path + ) frame_dur = metadata.frame_dur logger.info( f"Duration of a frame in dataset, in seconds: {frame_dur}", @@ -127,7 +124,7 @@ def eval_frame_classification_model( if not validators.is_a_directory(output_dir): raise NotADirectoryError( - f'value for ``output_dir`` not recognized as a directory: {output_dir}' + f"value for ``output_dir`` not recognized as a directory: {output_dir}" ) # ---- get time for .csv file -------------------------------------------------------------------------------------- @@ -138,7 +135,7 @@ def eval_frame_classification_model( logger.info(f"loading spect scaler from path: {spect_scaler_path}") spect_standardizer = joblib.load(spect_scaler_path) else: - logger.info(f"not using a spect scaler") + logger.info("not using a spect scaler") spect_standardizer = None logger.info(f"loading labelmap from path: {labelmap_path}") @@ -146,11 +143,9 @@ def eval_frame_classification_model( labelmap = json.load(f) if transform_params is None: 
transform_params = {} - transform_params.update({'spect_standardizer': spect_standardizer}) + transform_params.update({"spect_standardizer": spect_standardizer}) item_transform = transforms.defaults.get_default_transform( - model_name, - "eval", - transform_params + model_name, "eval", transform_params ) if dataset_params is None: dataset_params = {} @@ -196,24 +191,20 @@ def eval_frame_classification_model( model.load_state_dict_from_path(checkpoint_path) - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) + trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. of val loaders metric_vals = trainer.validate(model, dataloaders=val_loader)[0] - metric_vals = {f'avg_{k}': v for k, v in metric_vals.items()} + metric_vals = {f"avg_{k}": v for k, v in metric_vals.items()} for metric_name, metric_val in metric_vals.items(): - if metric_name.startswith('avg_'): - logger.info( - f'{metric_name}: {metric_val:0.5f}' - ) + if metric_name.startswith("avg_"): + logger.info(f"{metric_name}: {metric_val:0.5f}") # create a "DataFrame" with just one row which we will save as a csv; # the idea is to be able to concatenate csvs from multiple runs of eval @@ -229,7 +220,9 @@ def eval_frame_classification_model( # TODO: is this still necessary after switching to Lightning? Stop saying "average"? 
# order metrics by name to be extra sure they will be consistent across runs row.update( - sorted([(k, v) for k, v in metric_vals.items() if k.startswith("avg_")]) + sorted( + [(k, v) for k, v in metric_vals.items() if k.startswith("avg_")] + ) ) # pass index into dataframe, needed when using all scalar values (a single row) diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py index 083fef8cb..594add11a 100644 --- a/src/vak/eval/parametric_umap.py +++ b/src/vak/eval/parametric_umap.py @@ -1,24 +1,19 @@ """Function that evaluates trained models in the parametric UMAP family.""" from __future__ import annotations -from collections import OrderedDict -from datetime import datetime import logging import pathlib +from collections import OrderedDict +from datetime import datetime -import pytorch_lightning as lightning import pandas as pd +import pytorch_lightning as lightning import torch.utils.data -from .. import ( - datasets, - models, - transforms, -) +from .. import datasets, models, transforms from ..common import validators from ..datasets.parametric_umap import ParametricUMAPDataset - logger = logging.getLogger(__name__) @@ -73,8 +68,8 @@ def eval_parametric_umap_model( """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( - (checkpoint_path,), - ('checkpoint_path',), + (checkpoint_path,), + ("checkpoint_path",), ): if path is not None: # because `spect_scaler_path` is optional if not validators.is_a_file(path): @@ -90,11 +85,13 @@ def eval_parametric_umap_model( logger.info( f"Loading metadata from dataset path: {dataset_path}", ) - metadata = datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) + metadata = datasets.parametric_umap.Metadata.from_dataset_path( + dataset_path + ) if not validators.is_a_directory(output_dir): raise NotADirectoryError( - f'value for ``output_dir`` not recognized as a directory: {output_dir}' + 
f"value for ``output_dir`` not recognized as a directory: {output_dir}" ) # ---- get time for .csv file -------------------------------------------------------------------------------------- @@ -103,14 +100,12 @@ def eval_parametric_umap_model( # ---------------- load data for evaluation ------------------------------------------------------------------------ if transform_params is None: transform_params = {} - if 'padding' not in transform_params and model_name == 'ConvEncoderUMAP': + if "padding" not in transform_params and model_name == "ConvEncoderUMAP": padding = models.convencoder_umap.get_default_padding(metadata.shape) - transform_params['padding'] = padding + transform_params["padding"] = padding item_transform = transforms.defaults.get_default_transform( - model_name, - "eval", - transform_params + model_name, "eval", transform_params ) if dataset_params is None: dataset_params = {} @@ -138,22 +133,18 @@ def eval_parametric_umap_model( model.load_state_dict_from_path(checkpoint_path) - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) + trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. 
of val loaders metric_vals = trainer.validate(model, dataloaders=val_loader)[0] for metric_name, metric_val in metric_vals.items(): - logger.info( - f'{metric_name}: {metric_val:0.5f}' - ) + logger.info(f"{metric_name}: {metric_val:0.5f}") # create a "DataFrame" with just one row which we will save as a csv; # the idea is to be able to concatenate csvs from multiple runs of eval @@ -161,15 +152,12 @@ def eval_parametric_umap_model( [ ("model_name", model_name), ("checkpoint_path", checkpoint_path), - ("spect_scaler_path", spect_scaler_path), ("dataset_path", dataset_path), ] ) # TODO: is this still necessary after switching to Lightning? Stop saying "average"? # order metrics by name to be extra sure they will be consistent across runs - row.update( - sorted([(k, v) for k, v in metric_vals.items()]) - ) + row.update(sorted([(k, v) for k, v in metric_vals.items()])) # pass index into dataframe, needed when using all scalar values (a single row) # throw away index below when saving to avoid extra column diff --git a/src/vak/learncurve/__init__.py b/src/vak/learncurve/__init__.py index c33192d7f..29ca75d15 100644 --- a/src/vak/learncurve/__init__.py +++ b/src/vak/learncurve/__init__.py @@ -1,7 +1,6 @@ from . 
import learncurve from .learncurve import learning_curve - __all__ = [ "learncurve", "learning_curve", diff --git a/src/vak/learncurve/curvefit.py b/src/vak/learncurve/curvefit.py index b04f52516..ed2b95be2 100644 --- a/src/vak/learncurve/curvefit.py +++ b/src/vak/learncurve/curvefit.py @@ -18,8 +18,8 @@ def residual_two_functions(params, x, y1, y1err, y2, y2err): c = params[2] beta = params[3] asymptote = params[4] - diff1 = (y1 - (asymptote + b * alpha ** x)) ** 2 / y1err - diff2 = (y2 - (asymptote + c * beta ** x)) ** 2 / y2err + diff1 = (y1 - (asymptote + b * alpha**x)) ** 2 / y1err + diff2 = (y2 - (asymptote + c * beta**x)) ** 2 / y2err return np.concatenate((diff1, diff2)) @@ -92,12 +92,17 @@ def fit_learning_curve( "Number of elements in train_set_size does not match number of columns in error_test" ) - fitfunc = lambda p, x: p[0] + p[1] * x - errfunc = lambda p, x, y, err: (y - fitfunc(p, x)) / err + def fitfunc(p, x): + return p[0] + p[1] * x + + def errfunc(p, x, y, err): + return (y - fitfunc(p, x)) / err logx = np.log10(train_set_size) - if error_train is None: # if we just have test error, fit with power function + if ( + error_train is None + ): # if we just have test error, fit with power function y = np.mean(error_test, axis=1) logy = np.log10(y) yerr = np.std(error_test, axis=1) @@ -132,16 +137,14 @@ def fit_learning_curve( logy2err = y2err / y # take mean of logy as best estimate of horizontal line estimate = np.average(logy2, weights=logy2err) - a = (10.0 ** estimate) / 2 + a = (10.0**estimate) / 2 return a, b, alpha elif error_train is not None and funcs == 2: y1 = np.mean(error_test, axis=1) y1err = np.std(error_test, axis=1) - logy1 = np.log10(y1) y2 = np.mean(error_train, axis=1) y2err = np.std(error_train, axis=1) - logy2 = np.log10(y2) if len(pinit) < 3: # if default pinit from function declaration # change instead to default pinit in next line pinit = [1.0, -1.0, 1.0, 1.0, 0.05] diff --git a/src/vak/learncurve/frame_classification.py 
b/src/vak/learncurve/frame_classification.py index 7e765d69e..82da2ce36 100644 --- a/src/vak/learncurve/frame_classification.py +++ b/src/vak/learncurve/frame_classification.py @@ -6,16 +6,11 @@ import pandas as pd -from .dirname import replicate_dirname, train_dur_dirname +from .. import common, datasets +from ..common.converters import expanded_user_path from ..eval.frame_classification import eval_frame_classification_model from ..train.frame_classification import train_frame_classification_model -from .. import ( - common, - datasets, -) -from ..common.converters import expanded_user_path -from ..common.paths import generate_results_dir_name_as_path - +from .dirname import replicate_dirname, train_dur_dirname logger = logging.getLogger(__name__) @@ -32,7 +27,7 @@ def learning_curve_for_frame_classification_model( val_transform_params: dict | None = None, val_dataset_params: dict | None = None, results_path: str | pathlib.Path = None, - post_tfm_kwargs: dict | None =None, + post_tfm_kwargs: dict | None = None, normalize_spectrograms: bool = True, shuffle: bool = True, val_step: int | None = None, @@ -146,7 +141,9 @@ def learning_curve_for_frame_classification_model( logger.info( f"Loading dataset from path: {dataset_path}", ) - metadata = datasets.frame_classification.Metadata.from_dataset_path(dataset_path) + metadata = datasets.frame_classification.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) @@ -174,25 +171,31 @@ def learning_curve_for_frame_classification_model( dataset_df = dataset_df[ (dataset_df.train_dur.notna()) & (dataset_df.replicate_num.notna()) ] - train_durs = sorted(dataset_df['train_dur'].unique()) - replicate_nums = [int(replicate_num) - for replicate_num in sorted(dataset_df['replicate_num'].unique())] + train_durs = sorted(dataset_df["train_dur"].unique()) + replicate_nums = [ + int(replicate_num) + for replicate_num in 
sorted(dataset_df["replicate_num"].unique()) + ] to_do = [] for train_dur in train_durs: for replicate_num in replicate_nums: to_do.append((train_dur, replicate_num)) # ---- main loop that creates "learning curve" --------------------------------------------------------------------- - logger.info(f"Starting training for learning curve.") + logger.info("Starting training for learning curve.") for train_dur, replicate_num in to_do: logger.info( f"Training model with training set of size: {train_dur}s, replicate number {replicate_num}.", ) - results_path_this_train_dur = results_path / train_dur_dirname(train_dur) + results_path_this_train_dur = results_path / train_dur_dirname( + train_dur + ) if not results_path_this_train_dur.exists(): results_path_this_train_dur.mkdir() - results_path_this_replicate = results_path_this_train_dur / replicate_dirname(replicate_num) + results_path_this_replicate = ( + results_path_this_train_dur / replicate_dirname(replicate_num) + ) results_path_this_replicate.mkdir() logger.info( @@ -225,12 +228,8 @@ def learning_curve_for_frame_classification_model( split=split, ) - logger.info( - f"Evaluating model from replicate {replicate_num} " - ) - results_model_root = ( - results_path_this_replicate.joinpath(model_name) - ) + logger.info(f"Evaluating model from replicate {replicate_num} ") + results_model_root = results_path_this_replicate.joinpath(model_name) ckpt_root = results_model_root.joinpath("checkpoints") ckpt_paths = sorted(ckpt_root.glob("*.pt")) if any(["max-val-acc" in str(ckpt_path) for ckpt_path in ckpt_paths]): @@ -250,20 +249,12 @@ def learning_curve_for_frame_classification_model( f"did not find a single checkpoint path, instead found:\n{ckpt_paths}" ) ckpt_path = ckpt_paths[0] - logger.info( - f"Using checkpoint: {ckpt_path}" - ) - labelmap_path = results_path_this_replicate.joinpath( - "labelmap.json" - ) - logger.info( - f"Using labelmap: {labelmap_path}" - ) + logger.info(f"Using checkpoint: {ckpt_path}") + 
labelmap_path = results_path_this_replicate.joinpath("labelmap.json") + logger.info(f"Using labelmap: {labelmap_path}") if normalize_spectrograms: - spect_scaler_path = ( - results_path_this_replicate.joinpath( - "StandardizeSpect" - ) + spect_scaler_path = results_path_this_replicate.joinpath( + "StandardizeSpect" ) logger.info( f"Using spect scaler to normalize: {spect_scaler_path}", @@ -295,9 +286,13 @@ def learning_curve_for_frame_classification_model( eval_dfs = [] for train_dur, replicate_num in to_do: - results_path_this_train_dur = results_path / train_dur_dirname(train_dur) - results_path_this_replicate = results_path_this_train_dur / replicate_dirname(replicate_num) - eval_csv_path = sorted(results_path_this_replicate.glob('eval*.csv')) + results_path_this_train_dur = results_path / train_dur_dirname( + train_dur + ) + results_path_this_replicate = ( + results_path_this_train_dur / replicate_dirname(replicate_num) + ) + eval_csv_path = sorted(results_path_this_replicate.glob("eval*.csv")) if not len(eval_csv_path) == 1: raise ValueError( "Did not find exactly one eval results csv file in replicate directory after running learncurve. " diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 2200593a2..d4f0f08fb 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -4,12 +4,9 @@ import logging import pathlib -from .frame_classification import learning_curve_for_frame_classification_model -from .. import ( - models -) +from .. 
import models from ..common.converters import expanded_user_path - +from .frame_classification import learning_curve_for_frame_classification_model logger = logging.getLogger(__name__) @@ -26,7 +23,7 @@ def learning_curve( val_transform_params: dict | None = None, val_dataset_params: dict | None = None, results_path: str | pathlib.Path = None, - post_tfm_kwargs: dict | None =None, + post_tfm_kwargs: dict | None = None, normalize_spectrograms: bool = True, shuffle: bool = True, val_step: int | None = None, @@ -147,6 +144,4 @@ def learning_curve( device=device, ) else: - raise ValueError( - f"Model family not recognized: {model_family}" - ) + raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/metrics/__init__.py b/src/vak/metrics/__init__.py index f175142f2..091fd2c73 100644 --- a/src/vak/metrics/__init__.py +++ b/src/vak/metrics/__init__.py @@ -1,2 +1,2 @@ -from .classification import * -from .distance import * +from .classification import * # noqa: F401, F403 +from .distance import * # noqa: F401, F403 diff --git a/src/vak/metrics/classification/__init__.py b/src/vak/metrics/classification/__init__.py index 9dd3ef661..91f0a3451 100644 --- a/src/vak/metrics/classification/__init__.py +++ b/src/vak/metrics/classification/__init__.py @@ -1,4 +1,4 @@ -from .classification import * +from .classification import Accuracy __all__ = [ diff --git a/src/vak/metrics/classification/classification.py b/src/vak/metrics/classification/classification.py index 28f57ad6c..5451c1e93 100644 --- a/src/vak/metrics/classification/classification.py +++ b/src/vak/metrics/classification/classification.py @@ -1,6 +1,5 @@ from . 
import functional as F - __all__ = [ "Accuracy", ] diff --git a/src/vak/metrics/distance/__init__.py b/src/vak/metrics/distance/__init__.py index 9b45ea9d4..6815e919f 100644 --- a/src/vak/metrics/distance/__init__.py +++ b/src/vak/metrics/distance/__init__.py @@ -1,3 +1,3 @@ -from .distance import * +from .distance import Levenshtein, SegmentErrorRate __all__ = ["Levenshtein", "SegmentErrorRate"] diff --git a/src/vak/metrics/distance/functional.py b/src/vak/metrics/distance/functional.py index cb3d7d647..e429e3a82 100644 --- a/src/vak/metrics/distance/functional.py +++ b/src/vak/metrics/distance/functional.py @@ -84,15 +84,15 @@ def segment_error_rate(y_pred, y_true): ------- Levenshtein distance / len(y_true) """ - if type(y_true) != str or type(y_pred) != str: + if not isinstance(y_true, str) or not isinstance(y_pred, str): raise TypeError("Both `y_true` and `y_pred` must be of type `str") # handle divide by zero edge cases if len(y_true) == 0 and len(y_pred) == 0: - return 0. + return 0.0 elif len(y_true) == 0 and len(y_pred) != 0: raise ValueError( - f'segment error rate is undefined when length of y_true is zero' + "segment error rate is undefined when length of y_true is zero" ) return levenshtein(y_pred, y_true) / len(y_true) diff --git a/src/vak/models/__init__.py b/src/vak/models/__init__.py index 81ce02a4c..5a8e9d031 100644 --- a/src/vak/models/__init__.py +++ b/src/vak/models/__init__.py @@ -1,18 +1,12 @@ -from . import ( - base, - decorator, - definition, - registry, -) +from . 
import base, decorator, definition, registry from .base import Model from .convencoder_umap import ConvEncoderUMAP -from .get import get from .ed_tcn import ED_TCN -from .teenytweetynet import TeenyTweetyNet -from .tweetynet import TweetyNet from .frame_classification_model import FrameClassificationModel +from .get import get from .parametric_umap_model import ParametricUMAPModel - +from .teenytweetynet import TeenyTweetyNet +from .tweetynet import TweetyNet __all__ = [ "base", diff --git a/src/vak/models/base.py b/src/vak/models/base.py index 163ad4275..2aa47022a 100644 --- a/src/vak/models/base.py +++ b/src/vak/models/base.py @@ -2,11 +2,12 @@ that other families of models should subclass. """ from __future__ import annotations + import inspect -from typing import Callable, ClassVar, Type +from typing import Callable, ClassVar -import torch import pytorch_lightning as lightning +import torch from .definition import ModelDefinition from .definition import validate as validate_definition @@ -31,13 +32,16 @@ class Model(lightning.LightningModule): using a ``vak.model.ModelDefinition``; see the documentation on that class for more detail. """ + definition: ClassVar[ModelDefinition] - def __init__(self, - network: torch.nn.Module | dict[str: torch.nn.Module] | None = None, - loss: torch.nn.Module | Callable | None = None, - optimizer: torch.optim.Optimizer | None = None, - metrics: dict[str: Type] | None = None): + def __init__( + self, + network: torch.nn.Module | dict | None = None, + loss: torch.nn.Module | Callable | None = None, + optimizer: torch.optim.Optimizer | None = None, + metrics: dict | None = None, + ): """Initializes an instance of a model, using its definition. 
Takes in instances of the attributes defined by the class variable @@ -73,25 +77,25 @@ def __init__(self, super().__init__() # check that we are a sub-class of some other class with required class variables - if not hasattr(self, 'definition'): + if not hasattr(self, "definition"): raise ValueError( - 'This model does not have a definition.' - 'Define a model by wrapping a class with the required class variables with ' - 'a ``vak.models`` decorator, e.g. ``vak.models.windowed_frame_classification_model``' + "This model does not have a definition." + "Define a model by wrapping a class with the required class variables with " + "a ``vak.models`` decorator, e.g. ``vak.models.windowed_frame_classification_model``" ) try: validate_definition(self.definition) except ModelDefinitionValidationError as err: raise ValueError( - 'Creating model instance failed because model definition is invalid.' + "Creating model instance failed because model definition is invalid." ) from err # ---- validate any instances that user passed in self.validate_init(network, loss, optimizer, metrics) if network is None: - net_kwargs = self.definition.default_config.get('network') + net_kwargs = self.definition.default_config.get("network") if isinstance(self.definition.network, dict): network = { network_name: network_class(**net_kwargs[network_name]) @@ -103,14 +107,14 @@ def __init__(self, if loss is None: if inspect.isclass(self.definition.loss): - loss_kwargs = self.definition.default_config.get('loss') + loss_kwargs = self.definition.default_config.get("loss") loss = self.definition.loss(**loss_kwargs) elif inspect.isfunction(self.definition.loss): loss = self.definition.loss self.loss = loss if optimizer is None: - optimizer_kwargs = self.definition.default_config.get('optimizer') + optimizer_kwargs = self.definition.default_config.get("optimizer") if isinstance(network, dict): params = [ param @@ -119,11 +123,13 @@ def __init__(self, ] else: params = network.parameters() - optimizer = 
self.definition.optimizer(params=params, **optimizer_kwargs) + optimizer = self.definition.optimizer( + params=params, **optimizer_kwargs + ) self.optimizer = optimizer if metrics is None: - metric_kwargs = self.definition.default_config.get('metrics') + metric_kwargs = self.definition.default_config.get("metrics") metrics = {} for metric_name, metric_class in self.definition.metrics.items(): metric_class_kwargs = metric_kwargs.get(metric_name, {}) @@ -131,11 +137,13 @@ def __init__(self, self.metrics = metrics @classmethod - def validate_init(cls, - network: torch.nn.Module | dict[str: torch.nn.Module] | None = None, - loss: torch.nn.Module | Callable | None = None, - optimizer: torch.optim.Optimizer | None = None, - metrics: dict[str: Type] | None = None): + def validate_init( + cls, + network: torch.nn.Module | dict | None = None, + loss: torch.nn.Module | Callable | None = None, + optimizer: torch.optim.Optimizer | None = None, + metrics: dict | None = None, + ): """Validate arguments to ``vak.models.base.Model.__init__``. 
Parameters @@ -171,38 +179,50 @@ def validate_init(cls, if inspect.isclass(cls.definition.network): if not isinstance(network, cls.definition.network): raise TypeError( - f'``network`` should be an instance of {cls.definition.network}' - f'but was of type {type(network)}' + f"``network`` should be an instance of {cls.definition.network}" + f"but was of type {type(network)}" ) elif isinstance(cls.definition.network, dict): if not isinstance(network, dict): raise TypeError( - 'Expected ``network`` to be a ``dict`` mapping network names ' - f'to ``torch.nn.Module`` instances, but type was {type(network)}' + "Expected ``network`` to be a ``dict`` mapping network names " + f"to ``torch.nn.Module`` instances, but type was {type(network)}" ) - expected_network_dict_keys = list(cls.definition.network.keys()) + expected_network_dict_keys = list( + cls.definition.network.keys() + ) network_dict_keys = list(network.keys()) - if not all([ - expected_network_dict_key in network_dict_keys - for expected_network_dict_key in expected_network_dict_keys - ]): - missing_keys = set(expected_network_dict_keys) - set(network_dict_keys) + if not all( + [ + expected_network_dict_key in network_dict_keys + for expected_network_dict_key in expected_network_dict_keys + ] + ): + missing_keys = set(expected_network_dict_keys) - set( + network_dict_keys + ) raise ValueError( - f'The following keys were missing from the ``network`` dict: {missing_keys}' + f"The following keys were missing from the ``network`` dict: {missing_keys}" ) if any( - [network_dict_key not in expected_network_dict_keys - for network_dict_key in network_dict_keys] + [ + network_dict_key not in expected_network_dict_keys + for network_dict_key in network_dict_keys + ] ): - extra_keys = set(network_dict_keys) - set(expected_network_dict_keys) + extra_keys = set(network_dict_keys) - set( + expected_network_dict_keys + ) raise ValueError( - f'The following keys in the ``network`` dict are not valid: {extra_keys}.' 
- f'Valid keys are: {expected_network_dict_keys}' + f"The following keys in the ``network`` dict are not valid: {extra_keys}." + f"Valid keys are: {expected_network_dict_keys}" ) for network_name, network_instance in network.items(): - if not isinstance(network_instance, cls.definition.network[network_name]): + if not isinstance( + network_instance, cls.definition.network[network_name] + ): raise TypeError( f"Network with name '{network_name}' in ``network`` dict " f"should be an instance of {cls.definition.network[network_name]}" @@ -210,38 +230,36 @@ def validate_init(cls, ) else: raise TypeError( - f'Invalid type for ``network``: {type(network)}' + f"Invalid type for ``network``: {type(network)}" ) if loss: if issubclass(cls.definition.loss, torch.nn.Module): if not isinstance(loss, cls.definition.loss): raise TypeError( - f'``loss`` should be an instance of {cls.definition.loss}' - f'but was of type {type(loss)}' + f"``loss`` should be an instance of {cls.definition.loss}" + f"but was of type {type(loss)}" ) elif callable(cls.definition.loss): if loss is not cls.definition.loss: raise ValueError( - f'``loss`` should be the following callable (probably a function): {cls.definition.loss}' + f"``loss`` should be the following callable (probably a function): {cls.definition.loss}" ) else: - raise TypeError( - f'Invalid type for ``loss``: {type(loss)}' - ) + raise TypeError(f"Invalid type for ``loss``: {type(loss)}") if optimizer: if not isinstance(optimizer, cls.definition.optimizer): raise TypeError( - f'``optimizer`` should be an instance of {cls.definition.optimizer}' - f'but was of type {type(optimizer)}' + f"``optimizer`` should be an instance of {cls.definition.optimizer}" + f"but was of type {type(optimizer)}" ) if metrics: if not isinstance(metrics, dict): raise TypeError( - '``metrics`` should be a ``dict`` mapping string metric names ' - f'to callable metrics, but type of ``metrics`` was {type(metrics)}' + "``metrics`` should be a ``dict`` mapping string 
metric names " + f"to callable metrics, but type of ``metrics`` was {type(metrics)}" ) for metric_name, metric_callable in metrics.items(): if metric_name not in cls.definition.metrics: @@ -251,7 +269,9 @@ def validate_init(cls, f"Valid metric names are: {', '.join(list(cls.definition.metrics.keys()))}" ) - if not isinstance(metric_callable, cls.definition.metrics[metric_name]): + if not isinstance( + metric_callable, cls.definition.metrics[metric_name] + ): raise TypeError( f"metric '{metric_name}' should be an instance of {cls.definition.metrics[metric_name]}" f"but was of type {type(metric_callable)}" @@ -286,7 +306,7 @@ def load_state_dict_from_path(self, ckpt_path): it does not return anything. """ ckpt = torch.load(ckpt_path) - self.load_state_dict(ckpt['state_dict']) + self.load_state_dict(ckpt["state_dict"]) @classmethod def attributes_from_config(cls, config: dict): @@ -323,7 +343,9 @@ class variables to ``Callable`` functions, used to measure performance of the model. """ - network_kwargs = config.get('network', cls.definition.default_config['network']) + network_kwargs = config.get( + "network", cls.definition.default_config["network"] + ) if inspect.isclass(cls.definition.network): network = cls.definition.network(**network_kwargs) elif isinstance(cls.definition.network, dict): @@ -341,16 +363,22 @@ class variables else: params = network.parameters() - optimizer_kwargs = config.get('optimizer', cls.definition.default_config['optimizer']) + optimizer_kwargs = config.get( + "optimizer", cls.definition.default_config["optimizer"] + ) optimizer = cls.definition.optimizer(params=params, **optimizer_kwargs) if inspect.isclass(cls.definition.loss): - loss_kwargs = config.get('loss', cls.definition.default_config['loss']) + loss_kwargs = config.get( + "loss", cls.definition.default_config["loss"] + ) loss = cls.definition.loss(**loss_kwargs) else: loss = cls.definition.loss - metrics_config = config.get('metrics', cls.definition.default_config['metrics']) + 
metrics_config = config.get( + "metrics", cls.definition.default_config["metrics"] + ) metrics = {} for metric_name, metric_class in cls.definition.metrics.items(): metrics_class_kwargs = metrics_config.get(metric_name, {}) @@ -375,4 +403,6 @@ def from_config(cls, config: dict): initialized using parameters from ``config``. """ network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls(network=network, loss=loss, optimizer=optimizer, metrics=metrics) + return cls( + network=network, loss=loss, optimizer=optimizer, metrics=metrics + ) diff --git a/src/vak/models/convencoder_umap.py b/src/vak/models/convencoder_umap.py index ace0e4193..edffdf11f 100644 --- a/src/vak/models/convencoder_umap.py +++ b/src/vak/models/convencoder_umap.py @@ -11,13 +11,9 @@ import torch -from .. import ( - metrics, - nets, - nn, -) -from .parametric_umap_model import ParametricUMAPModel +from .. import metrics, nets, nn from .decorator import model +from .parametric_umap_model import ParametricUMAPModel @model(family=ParametricUMAPModel) @@ -56,16 +52,18 @@ class ConvEncoderUMAP: https://direct.mit.edu/neco/article/33/11/2881/107068. 
""" - network = {'encoder': nets.ConvEncoder} + + network = {"encoder": nets.ConvEncoder} loss = nn.UmapLoss optimizer = torch.optim.AdamW - metrics = {'acc': metrics.Accuracy, - 'levenshtein': metrics.Levenshtein, - 'segment_error_rate': metrics.SegmentErrorRate, - 'loss': torch.nn.CrossEntropyLoss} + metrics = { + "acc": metrics.Accuracy, + "levenshtein": metrics.Levenshtein, + "segment_error_rate": metrics.SegmentErrorRate, + "loss": torch.nn.CrossEntropyLoss, + } default_config = { - 'optimizer': - {'lr': 1e-3}, + "optimizer": {"lr": 1e-3}, } @@ -75,5 +73,8 @@ def get_default_padding(shape): Rounds up to nearest tens place """ rounded_up = tuple(10 * math.ceil(x / 10) for x in shape) - padding = tuple(rounded_up_x - shape_x for (rounded_up_x, shape_x) in zip(rounded_up, shape)) + padding = tuple( + rounded_up_x - shape_x + for (rounded_up_x, shape_x) in zip(rounded_up, shape) + ) return padding diff --git a/src/vak/models/decorator.py b/src/vak/models/decorator.py index 88b8999a0..a0aa717fe 100644 --- a/src/vak/models/decorator.py +++ b/src/vak/models/decorator.py @@ -9,6 +9,7 @@ and have all model methods. """ from __future__ import annotations + from typing import Type from .base import Model @@ -22,6 +23,7 @@ class ModelDefinitionValidationError(Exception): Used by :func:`vak.models.decorator.model` decorator. """ + pass @@ -56,29 +58,30 @@ def model(family: Type[Model]): that will be used when making new instances of the model. """ + def _model(definition: Type): if not issubclass(family, Model): raise TypeError( - 'The ``family`` argument to the ``vak.models.model`` decorator' - 'should be a subclass of ``vak.models.base.Model``,' - f'but the type was: {type(family)}, ' - 'which was not recognized as a subclass ' - 'of ``vak.models.base.Model``.' 
+ "The ``family`` argument to the ``vak.models.model`` decorator" + "should be a subclass of ``vak.models.base.Model``," + f"but the type was: {type(family)}, " + "which was not recognized as a subclass " + "of ``vak.models.base.Model``." ) try: validate_definition(definition) except ValueError as err: raise ModelDefinitionValidationError( - f'Validation failed for the following model definition:\n{definition}' + f"Validation failed for the following model definition:\n{definition}" ) from err except TypeError as err: raise ModelDefinitionValidationError( - f'Validation failed for the following model definition:\n{definition}' + f"Validation failed for the following model definition:\n{definition}" ) from err attributes = dict(family.__dict__) - attributes.update({'definition': definition}) + attributes.update({"definition": definition}) subclass_name = definition.__name__ subclass = type(subclass_name, (family,), attributes) subclass.__module__ = definition.__module__ diff --git a/src/vak/models/definition.py b/src/vak/models/definition.py index 28851a98a..14b5435de 100644 --- a/src/vak/models/definition.py +++ b/src/vak/models/definition.py @@ -2,23 +2,24 @@ of a neural network model; the abstraction of how models are declared with code in vak.""" from __future__ import annotations -import dataclasses +import dataclasses import inspect -from typing import Callable, Type, Union +from typing import Type, Union import torch - REQUIRED_MODEL_DEFINITION_CLASS_VARS = ( - 'network', - 'loss', - 'optimizer', - 'metrics', - 'default_config' + "network", + "loss", + "optimizer", + "metrics", + "default_config", ) -VALID_CONFIG_KEYS = REQUIRED_MODEL_DEFINITION_CLASS_VARS[:-1] # everything but 'default_config' +VALID_CONFIG_KEYS = REQUIRED_MODEL_DEFINITION_CLASS_VARS[ + :-1 +] # everything but 'default_config' @dataclasses.dataclass @@ -47,20 +48,21 @@ class ModelDefinition: sub-classes that represent model families. 
E.g., those classes will do: ``network = self.definition.network(**self.definition.default_config['network'])``. """ - network: Union[torch.nn.Module, dict[str: torch.nn.Module]] - loss: dict[Union[str: Callable, str: torch.nn.Module]] + + network: Union[torch.nn.Module, dict] + loss: dict optimizer: torch.optim.Optimizer - metrics: dict[str: Type] + metrics: dict default_config: dict # default that we set ``definition.default_config`` to, # if definition does not have that class variable declared DEFAULT_DEFAULT_CONFIG = { - 'network': {}, - 'loss': {}, - 'optimizer': {}, - 'metrics': {} + "network": {}, + "loss": {}, + "optimizer": {}, + "metrics": {}, } @@ -128,7 +130,7 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: # need to set this default first # so we don't throw error when checking class variables # if user did not specify ``default_config`` - if not hasattr(definition, 'default_config'): + if not hasattr(definition, "default_config"): definition.default_config = DEFAULT_DEFAULT_CONFIG else: # if they **did** specify ``default_config``, @@ -144,20 +146,32 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: key: val for key, val in vars(definition).items() # keep class vars; throw out __module__, __doc__, etc. 
- if not (key.startswith('__') and key.endswith('__')) + if not (key.startswith("__") and key.endswith("__")) } definition_class_var_names = list(definition_vars.keys()) - if not all([expected_class_var_name in definition_class_var_names - for expected_class_var_name in REQUIRED_MODEL_DEFINITION_CLASS_VARS]): - missing_var_name = set(REQUIRED_MODEL_DEFINITION_CLASS_VARS) - set(definition_class_var_names) + if not all( + [ + expected_class_var_name in definition_class_var_names + for expected_class_var_name in REQUIRED_MODEL_DEFINITION_CLASS_VARS + ] + ): + missing_var_name = set(REQUIRED_MODEL_DEFINITION_CLASS_VARS) - set( + definition_class_var_names + ) raise ValueError( f"Model definition is missing the following class variable(s): {missing_var_name}" ) # ---- check if there are any extra class variables - if any([modeldef_var_name not in REQUIRED_MODEL_DEFINITION_CLASS_VARS - for modeldef_var_name in definition_class_var_names]): - extra_var_name = set(definition_class_var_names) - set(REQUIRED_MODEL_DEFINITION_CLASS_VARS) + if any( + [ + modeldef_var_name not in REQUIRED_MODEL_DEFINITION_CLASS_VARS + for modeldef_var_name in definition_class_var_names + ] + ): + extra_var_name = set(definition_class_var_names) - set( + REQUIRED_MODEL_DEFINITION_CLASS_VARS + ) raise ValueError( f"Model definition has invalid class variable(s): {extra_var_name}." 
f"Valid class variables are: {REQUIRED_MODEL_DEFINITION_CLASS_VARS}" @@ -173,7 +187,7 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: # we do validation "by hand" which is very verbose # ---- validate network - network_obj = getattr(definition, 'network') + network_obj = getattr(definition, "network") if inspect.isclass(network_obj): if not issubclass(network_obj, torch.nn.Module): raise TypeError( @@ -185,15 +199,15 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: for network_dict_key, network_dict_val in network_obj.items(): if not isinstance(network_dict_key, str): raise TypeError( - 'A model definition with a ``network`` variable that is a dict ' - 'should have keys that are strings, ' - f'but the following key has type {type(network_dict_key)}: {network_dict_key}' + "A model definition with a ``network`` variable that is a dict " + "should have keys that are strings, " + f"but the following key has type {type(network_dict_key)}: {network_dict_key}" ) if not issubclass(network_dict_val, torch.nn.Module): raise TypeError( - 'A model definition with a ``network`` variable that is a dict ' - f'should have string keys mapping to values that are torch.nn.Module subclasses, ' - f'but the following value has type {type(network_dict_val)}: {network_dict_val}' + "A model definition with a ``network`` variable that is a dict " + f"should have string keys mapping to values that are torch.nn.Module subclasses, " + f"but the following value has type {type(network_dict_val)}: {network_dict_val}" ) else: raise TypeError( @@ -203,7 +217,7 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: ) # ---- validate loss - loss_obj = getattr(definition, 'loss') + loss_obj = getattr(definition, "loss") # need complicated if-else here because issubclass throws an error if we don't pass it a class if inspect.isclass(loss_obj): if issubclass(loss_obj, torch.nn.Module): @@ -222,7 +236,7 @@ def 
validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: ) # ---- validate optimizer - optim_obj = getattr(definition, 'optimizer') + optim_obj = getattr(definition, "optimizer") if not issubclass(optim_obj, torch.optim.Optimizer): raise TypeError( "A model definition's 'optimizer' variable must be a subclass of torch.optim.Optimizer, " @@ -230,7 +244,7 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: ) # ---- validate metrics - metrics_obj = getattr(definition, 'metrics') + metrics_obj = getattr(definition, "metrics") if not isinstance(metrics_obj, dict): raise TypeError( "A model definition's 'metrics' variable must be a dict mapping string names to callables, " @@ -242,41 +256,45 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: f"A model definition's 'metrics' variable must be a dict mapping string names to callables, " f"but the following key has type {type(metrics_dict_key)}: {metrics_dict_key}" ) - if not (inspect.isclass(metrics_dict_val) and callable(metrics_dict_val)): + if not ( + inspect.isclass(metrics_dict_val) and callable(metrics_dict_val) + ): raise TypeError( "A model definition's 'metrics' variable must be a dict mapping " "string names to classes that define __call__ methods, " f"but the key '{metrics_dict_key}' maps to a value with type {type(metrics_dict_val)}, " - f'not recognized as callable.' + f"not recognized as callable." 
) # ---- validate default config - default_config = getattr(definition, 'default_config') + default_config = getattr(definition, "default_config") if not all( - [config_key in VALID_CONFIG_KEYS - for config_key in default_config.keys()] + [ + config_key in VALID_CONFIG_KEYS + for config_key in default_config.keys() + ] ): - invalid_keys = [config_key - for config_key in default_config.keys() - if config_key not in VALID_CONFIG_KEYS - ] + invalid_keys = [ + config_key + for config_key in default_config.keys() + if config_key not in VALID_CONFIG_KEYS + ] raise ValueError( - f'Invalid keys in default_config: {invalid_keys}.' - f'Valid keys are: {VALID_CONFIG_KEYS}' + f"Invalid keys in default_config: {invalid_keys}." + f"Valid keys are: {VALID_CONFIG_KEYS}" ) # -------- validate 'network' config - network_config = default_config.get('network') + network_config = default_config.get("network") if network_config is None: if inspect.isclass(definition.network): # calling 'if issubclass(definition.network, torch.nn.Module)' # would raise an error when definition.network is a dict - definition.default_config['network'] = {} + definition.default_config["network"] = {} elif isinstance(definition.network, dict): - definition.default_config['network'] = { - network_name: {} - for network_name in definition.network.keys() + definition.default_config["network"] = { + network_name: {} for network_name in definition.network.keys() } elif len(network_config) > 0: if inspect.isclass(definition.network): @@ -287,33 +305,42 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: definition.network.__init__ ).parameters.keys() ) - if any([ - network_kwarg not in network_init_params - for network_kwarg in network_config.keys() - ]): - invalid_keys = set(network_config.keys()) - set(network_init_params) + if any( + [ + network_kwarg not in network_init_params + for network_kwarg in network_config.keys() + ] + ): + invalid_keys = set(network_config.keys()) - set( + 
network_init_params + ) raise ValueError( - f'The following keyword arguments specified in the ``default_config`` ' - f'for ``network`` are invalid: {invalid_keys}.' - f'Valid arguments are: {network_init_params}' + f"The following keyword arguments specified in the ``default_config`` " + f"for ``network`` are invalid: {invalid_keys}." + f"Valid arguments are: {network_init_params}" ) elif isinstance(definition.network, dict): - if any([network_name not in definition.network.keys() - for network_name in network_config.keys()]): - - invalid_network_names = [network_name - for network_name in network_config.keys() - if network_name not in definition.network.keys()] + if any( + [ + network_name not in definition.network.keys() + for network_name in network_config.keys() + ] + ): + invalid_network_names = [ + network_name + for network_name in network_config.keys() + if network_name not in definition.network.keys() + ] raise ValueError( - "When model definition's ``network`` is a ``dict`` mapping string names to ``torch.nn.Module``s," - "the definition's ``default_config`` should have only those string names as keys." - f"The following keys in the default_config for network are invalid: {invalid_network_names}." - f"Valid keys are these network names: {definition.network.keys()}" - "Please rewrite ``default_config`` so keys of ``default_config['network']`` " - "are only those string names, " - "and the corresponding values for those keys are keyword arguments for the networks." + "When model definition's ``network`` is a ``dict`` mapping string names to ``torch.nn.Module``s," + "the definition's ``default_config`` should have only those string names as keys." + f"The following keys in the default_config for network are invalid: {invalid_network_names}." 
+ f"Valid keys are these network names: {definition.network.keys()}" + "Please rewrite ``default_config`` so keys of ``default_config['network']`` " + "are only those string names, " + "and the corresponding values for those keys are keyword arguments for the networks." ) for network_name, network_kwargs in network_config.items(): network_init_params = list( @@ -321,21 +348,25 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: definition.network[network_name].__init__ ).parameters.keys() ) - if any([ - network_kwarg not in network_init_params - for network_kwarg in network_kwargs.keys() - ]): - invalid_keys = set(network_config.keys()) - set(network_init_params) + if any( + [ + network_kwarg not in network_init_params + for network_kwarg in network_kwargs.keys() + ] + ): + invalid_keys = set(network_config.keys()) - set( + network_init_params + ) raise ValueError( - f'The following keyword arguments specified in the ``default_config`` ' - f'for ``network`` are invalid: {invalid_keys}.' - f'Valid arguments are: {network_init_params}' + f"The following keyword arguments specified in the ``default_config`` " + f"for ``network`` are invalid: {invalid_keys}." + f"Valid arguments are: {network_init_params}" ) # -------- validate 'loss' config - loss_config = default_config.get('loss') + loss_config = default_config.get("loss") if loss_config is None: - definition.default_config['loss'] = {} + definition.default_config["loss"] = {} elif len(loss_config) > 0: if inspect.isfunction(definition.loss): raise ValueError( @@ -343,56 +374,64 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: "but loss is a function, not a class. Please only specify keyword arguments for classes." 
) loss_init_params = list( - inspect.signature( - definition.loss.__init__ - ).parameters.keys() + inspect.signature(definition.loss.__init__).parameters.keys() ) if any( - [loss_kwarg not in loss_init_params - for loss_kwarg in loss_config.keys()] + [ + loss_kwarg not in loss_init_params + for loss_kwarg in loss_config.keys() + ] ): - invalid_loss_kwargs = set(loss_config.keys()) - set(loss_init_params) + invalid_loss_kwargs = set(loss_config.keys()) - set( + loss_init_params + ) raise ValueError( - f'The following keyword arguments specified in the ``default_config`` ' - f'for ``loss`` are invalid: {invalid_loss_kwargs}.' - f'Valid arguments are: {loss_init_params}' + f"The following keyword arguments specified in the ``default_config`` " + f"for ``loss`` are invalid: {invalid_loss_kwargs}." + f"Valid arguments are: {loss_init_params}" ) # -------- validate 'optimizer' config - optimizer_config = default_config.get('optimizer') + optimizer_config = default_config.get("optimizer") if optimizer_config is None: - definition.default_config['optimizer'] = {} + definition.default_config["optimizer"] = {} elif len(optimizer_config) > 0: optimizer_init_params = list( - inspect.signature( - definition.optimizer.__init__ - ).parameters.keys() + inspect.signature(definition.optimizer.__init__).parameters.keys() ) if any( - [optimizer_kwarg not in optimizer_init_params - for optimizer_kwarg in optimizer_config.keys()] + [ + optimizer_kwarg not in optimizer_init_params + for optimizer_kwarg in optimizer_config.keys() + ] ): - invalid_optimizer_kwargs = set(optimizer_config.keys()) - set(optimizer_init_params) + invalid_optimizer_kwargs = set(optimizer_config.keys()) - set( + optimizer_init_params + ) raise ValueError( - f'The following keyword arguments specified in the ``default_config`` ' - f'for ``optimizer`` are invalid: {invalid_optimizer_kwargs}.' 
- f'Valid arguments are: {optimizer_init_params}' + f"The following keyword arguments specified in the ``default_config`` " + f"for ``optimizer`` are invalid: {invalid_optimizer_kwargs}." + f"Valid arguments are: {optimizer_init_params}" ) # -------- validate 'metrics' config - metrics_config = default_config.get('metrics') + metrics_config = default_config.get("metrics") if metrics_config is None: - definition.default_config['metrics'] = {} + definition.default_config["metrics"] = {} elif len(metrics_config) > 0: if any( - [metric_name not in definition.metrics - for metric_name in metrics_config.keys()] + [ + metric_name not in definition.metrics + for metric_name in metrics_config.keys() + ] ): - invalid_metric_names = set(metrics_config.keys()) - set(definition.metrics.keys()) + invalid_metric_names = set(metrics_config.keys()) - set( + definition.metrics.keys() + ) raise ValueError( - f'The following metric names specified in the ``default_config`` ' - f'for ``metrics`` are invalid: {invalid_metric_names}.' - f'Valid metric names are: {definition.metrics.keys()}' + f"The following metric names specified in the ``default_config`` " + f"for ``metrics`` are invalid: {invalid_metric_names}." 
+ f"Valid metric names are: {definition.metrics.keys()}" ) for metric_name, metric_class_config in metrics_config.items(): metric_class_init_params = list( @@ -401,10 +440,14 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: ).parameters.keys() ) if any( - [metric_class_kwarg not in metric_class_init_params - for metric_class_kwarg in metric_class_config.keys()] + [ + metric_class_kwarg not in metric_class_init_params + for metric_class_kwarg in metric_class_config.keys() + ] ): - invalid_metric_class_kwargs = set(metric_class_config.keys()) - set(metric_class_init_params) + invalid_metric_class_kwargs = set( + metric_class_config.keys() + ) - set(metric_class_init_params) raise ValueError( f"The following keyword arguments specified in the ``default_config`` " f"for 'metrics' class {definition.metrics[metric_name]} are invalid: " diff --git a/src/vak/models/ed_tcn.py b/src/vak/models/ed_tcn.py index c3dd79e82..2cd30109b 100644 --- a/src/vak/models/ed_tcn.py +++ b/src/vak/models/ed_tcn.py @@ -4,12 +4,9 @@ import torch -from .. import ( - metrics, - nets -) -from .frame_classification_model import FrameClassificationModel +from .. import metrics, nets from .decorator import model +from .frame_classification_model import FrameClassificationModel @model(family=FrameClassificationModel) @@ -38,14 +35,14 @@ class ED_TCN: Temporal convolutional networks for action segmentation and detection. In proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 156-165). 
""" + network = nets.ED_TCN loss = torch.nn.CrossEntropyLoss optimizer = torch.optim.Adam - metrics = {'acc': metrics.Accuracy, - 'levenshtein': metrics.Levenshtein, - 'segment_error_rate': metrics.SegmentErrorRate, - 'loss': torch.nn.CrossEntropyLoss} - default_config = { - 'optimizer': - {'lr': 0.003} + metrics = { + "acc": metrics.Accuracy, + "levenshtein": metrics.Levenshtein, + "segment_error_rate": metrics.SegmentErrorRate, + "loss": torch.nn.CrossEntropyLoss, } + default_config = {"optimizer": {"lr": 0.003}} diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index daaa7d7bb..ff51ee222 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -5,16 +5,15 @@ from __future__ import annotations import logging -from typing import Callable, ClassVar, Mapping, Type +from typing import Callable, ClassVar, Mapping import torch +from .. import transforms +from ..common import labels from . import base from .definition import ModelDefinition from .registry import model_family -from .. import transforms -from ..common import labels - logger = logging.getLogger(__name__) @@ -86,16 +85,18 @@ class FrameClassificationModel(base.Model): to string labels inside of ``validation_step``, for computing edit distance. 
""" + definition: ClassVar[ModelDefinition] - def __init__(self, - labelmap: Mapping, - network: torch.nn.Module | dict[str: torch.nn.Module] | None = None, - loss: torch.nn.Module | Callable | None = None, - optimizer: torch.optim.Optimizer | None = None, - metrics: dict[str: Type] | None = None, - post_tfm: Callable | None = None, - ): + def __init__( + self, + labelmap: Mapping, + network: torch.nn.Module | dict | None = None, + loss: torch.nn.Module | Callable | None = None, + optimizer: torch.optim.Optimizer | None = None, + metrics: dict | None = None, + post_tfm: Callable | None = None, + ): """Initialize a new instance of a :class:`~vak.models.frame_classification_model.FrameClassificationModel`. @@ -124,26 +125,35 @@ def __init__(self, post_tfm : callable Post-processing transform applied to predictions. """ - super().__init__(network=network, loss=loss, - optimizer=optimizer, metrics=metrics) + super().__init__( + network=network, loss=loss, optimizer=optimizer, metrics=metrics + ) self.labelmap = labelmap # replace any multiple character labels in mapping # with single-character labels # so that we do not affect edit distance computation # see https://github.com/NickleDave/vak/issues/373 - labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != 'unlabeled'] - if any([len(label) > 1 for label in labelmap_keys]): # only re-map if necessary + labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != "unlabeled"] + if any( + [len(label) > 1 for label in labelmap_keys] + ): # only re-map if necessary # (to minimize chance of knock-on bugs) - logger.info("Detected that labelmap has keys with multiple characters:" - f"\n{labelmap_keys}\n" - "Re-mapping labelmap used with to_labels_eval transform, using " - "function vak.labels.multi_char_labels_to_single_char") - self.eval_labelmap = labels.multi_char_labels_to_single_char(labelmap) + logger.info( + "Detected that labelmap has keys with multiple characters:" + f"\n{labelmap_keys}\n" + "Re-mapping labelmap 
used with to_labels_eval transform, using "
+                "function vak.labels.multi_char_labels_to_single_char"
+            )
+            self.eval_labelmap = labels.multi_char_labels_to_single_char(
+                labelmap
+            )
         else:
             self.eval_labelmap = labelmap
-        self.to_labels_eval = transforms.frame_labels.ToLabels(self.eval_labelmap)
+        self.to_labels_eval = transforms.frame_labels.ToLabels(
+            self.eval_labelmap
+        )
         self.post_tfm = post_tfm
 
     def configure_optimizers(self):
@@ -192,7 +202,7 @@ def training_step(self, batch: tuple, batch_idx: int):
         x, y = batch[0], batch[1]
         out = self.network(x)
         loss = self.loss(out, y)
-        self.log(f'train_loss', loss)
+        self.log("train_loss", loss)
         return loss
 
     def validation_step(self, batch: tuple, batch_idx: int):
@@ -265,19 +275,41 @@ def validation_step(self, batch: tuple, batch_idx: int):
         # TODO: figure out smarter way to do this
         for metric_name, metric_callable in self.metrics.items():
             if metric_name == "loss":
-                self.log(f'val_{metric_name}', metric_callable(out, y), batch_size=1, on_step=True)
+                self.log(
+                    f"val_{metric_name}",
+                    metric_callable(out, y),
+                    batch_size=1,
+                    on_step=True,
+                )
             elif metric_name == "acc":
-                self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
+                self.log(
+                    f"val_{metric_name}",
+                    metric_callable(y_pred, y),
+                    batch_size=1,
+                )
                 if self.post_tfm:
-                    self.log(f'val_{metric_name}_tfm',
-                             metric_callable(y_pred_tfm, y),
-                             batch_size=1, on_step=True)
-            elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
-                self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)
+                    self.log(
+                        f"val_{metric_name}_tfm",
+                        metric_callable(y_pred_tfm, y),
+                        batch_size=1,
+                        on_step=True,
+                    )
+            elif (
+                metric_name == "levenshtein"
+                or metric_name == "segment_error_rate"
+            ):
+                self.log(
+                    f"val_{metric_name}",
+                    metric_callable(y_pred_labels, y_labels),
+                    batch_size=1,
+                )
                 if self.post_tfm:
-                    self.log(f'val_{metric_name}_tfm',
-                             metric_callable(y_pred_tfm_labels, y_labels),
-                             batch_size=1, 
on_step=True) + self.log( + f"val_{metric_name}_tfm", + metric_callable(y_pred_tfm_labels, y_labels), + batch_size=1, + on_step=True, + ) def predict_step(self, batch: tuple, batch_idx: int): """Perform one prediction step. @@ -313,7 +345,9 @@ def predict_step(self, batch: tuple, batch_idx: int): return {source_path: y_pred} @classmethod - def from_config(cls, config: dict, labelmap: Mapping, post_tfm: Callable | None = None): + def from_config( + cls, config: dict, labelmap: Mapping, post_tfm: Callable | None = None + ): """Return an initialized model instance from a config ``dict`` Parameters @@ -333,10 +367,11 @@ def from_config(cls, config: dict, labelmap: Mapping, post_tfm: Callable | None initialized using parameters from ``config``. """ network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls(labelmap=labelmap, - network=network, - optimizer=optimizer, - loss=loss, - metrics=metrics, - post_tfm=post_tfm, - ) + return cls( + labelmap=labelmap, + network=network, + optimizer=optimizer, + loss=loss, + metrics=metrics, + post_tfm=post_tfm, + ) diff --git a/src/vak/models/get.py b/src/vak/models/get.py index a749cdbde..052a0e39b 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -8,12 +8,14 @@ from . import registry -def get(name: str, - config: dict, - input_shape: tuple[int, int, int], - num_classes: int | None = None, - labelmap: dict | None = None, - post_tfm: Callable | None = None): +def get( + name: str, + config: dict, + input_shape: tuple[int, int, int], + num_classes: int | None = None, + labelmap: dict | None = None, + post_tfm: Callable | None = None, +): """Get a model instance, given its name and a configuration as a :class:`dict`. 
@@ -56,20 +58,22 @@ def get(name: str, model_family = registry.MODEL_FAMILY_FROM_NAME[name] - if model_family == 'FrameClassificationModel': + if model_family == "FrameClassificationModel": # still need to special case model logic here net_init_params = list( inspect.signature( model_class.definition.network.__init__ ).parameters.keys() ) - if ('num_input_channels' in net_init_params) and ('num_freqbins' in net_init_params): + if ("num_input_channels" in net_init_params) and ( + "num_freqbins" in net_init_params + ): num_input_channels = input_shape[-3] num_freqbins = input_shape[-2] config["network"].update( num_classes=num_classes, num_input_channels=num_input_channels, - num_freqbins=num_freqbins + num_freqbins=num_freqbins, ) else: raise ValueError( @@ -77,14 +81,16 @@ def get(name: str, f"unable to determine network init arguments for model. Currently all models " f"in this family must have networks with parameters ``num_input_channels`` and ``num_freqbins``" ) - model = model_class.from_config(config=config, labelmap=labelmap, post_tfm=post_tfm) - elif model_family == 'ParametricUMAPModel': + model = model_class.from_config( + config=config, labelmap=labelmap, post_tfm=post_tfm + ) + elif model_family == "ParametricUMAPModel": encoder_init_params = list( inspect.signature( - model_class.definition.network['encoder'].__init__ + model_class.definition.network["encoder"].__init__ ).parameters.keys() ) - if ('input_shape' in encoder_init_params): + if "input_shape" in encoder_init_params: if "encoder" in config["network"]: config["network"]["encoder"].update(input_shape=input_shape) else: diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py index 540c6ad58..6a9c87f45 100644 --- a/src/vak/models/parametric_umap_model.py +++ b/src/vak/models/parametric_umap_model.py @@ -37,6 +37,7 @@ class ParametricUMAPModel(base.Model): Neural Computation, 33(11), 2881-2907. https://direct.mit.edu/neco/article/33/11/2881/107068. 
""" + definition: ClassVar[ModelDefinition] def __init__( @@ -44,19 +45,22 @@ def __init__( network: dict | None = None, loss: torch.nn.Module | Callable | None = None, optimizer: torch.optim.Optimizer | None = None, - metrics: dict[str: Type] | None = None, + metrics: dict[str:Type] | None = None, ): - super().__init__(network=network, loss=loss, - optimizer=optimizer, metrics=metrics) - self.encoder = network['encoder'] - self.decoder = network.get('decoder', None) + super().__init__( + network=network, loss=loss, optimizer=optimizer, metrics=metrics + ) + self.encoder = network["encoder"] + self.decoder = network.get("decoder", None) def configure_optimizers(self): return self.optimizer def training_step(self, batch, batch_idx): (edges_to_exp, edges_from_exp) = batch - embedding_to, embedding_from = self.encoder(edges_to_exp), self.encoder(edges_from_exp) + embedding_to, embedding_from = self.encoder( + edges_to_exp + ), self.encoder(edges_from_exp) if self.decoder is not None: reconstruction = self.decoder(embedding_to) @@ -64,7 +68,9 @@ def training_step(self, batch, batch_idx): else: reconstruction = None before_encoding = None - loss_umap, loss_reconstruction, loss = self.loss(embedding_to, embedding_from, reconstruction, before_encoding) + loss_umap, loss_reconstruction, loss = self.loss( + embedding_to, embedding_from, reconstruction, before_encoding + ) self.log("train_umap_loss", loss_umap) if loss_reconstruction: self.log("train_reconstruction_loss", loss_reconstruction) @@ -74,7 +80,9 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): (edges_to_exp, edges_from_exp) = batch - embedding_to, embedding_from = self.encoder(edges_to_exp), self.encoder(edges_from_exp) + embedding_to, embedding_from = self.encoder( + edges_to_exp + ), self.encoder(edges_from_exp) if self.decoder is not None: reconstruction = self.decoder(embedding_to) @@ -82,16 +90,19 @@ def validation_step(self, batch, batch_idx): else: reconstruction = 
None before_encoding = None - loss_umap, loss_reconstruction, loss = self.loss(embedding_to, embedding_from, reconstruction, before_encoding) + loss_umap, loss_reconstruction, loss = self.loss( + embedding_to, embedding_from, reconstruction, before_encoding + ) self.log("val_umap_loss", loss_umap, on_step=True) if loss_reconstruction: - self.log("val_reconstruction_loss", loss_reconstruction, on_step=True) + self.log( + "val_reconstruction_loss", loss_reconstruction, on_step=True + ) # note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap`` self.log("val_loss", loss, on_step=True) @classmethod - def from_config(cls, - config: dict): + def from_config(cls, config: dict): """Return an initialized model instance from a config ``dict`` Parameters @@ -107,10 +118,9 @@ def from_config(cls, initialized using parameters from ``config``. """ network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls(network=network, - optimizer=optimizer, - loss=loss, - metrics=metrics) + return cls( + network=network, optimizer=optimizer, loss=loss, metrics=metrics + ) class ParametricUMAPDatamodule(lightning.LightningDataModule): @@ -163,13 +173,28 @@ def __init__( self.model = ParametricUMAPModel(self.encoder, min_dist=self.min_dist) - def fit(self, trainer: lightning.Trainer, dataset_path: str | pathlib.Path, transform=None): + def fit( + self, + trainer: lightning.Trainer, + dataset_path: str | pathlib.Path, + transform=None, + ): from vak.datasets.parametric_umap import ParametricUMAPDataset - dataset = ParametricUMAPDataset.from_dataset_path(dataset_path, 'train', self.n_neighbors, self.metric, - self.random_state, self.num_epochs, transform) + + dataset = ParametricUMAPDataset.from_dataset_path( + dataset_path, + "train", + self.n_neighbors, + self.metric, + self.random_state, + self.num_epochs, + transform, + ) trainer.fit( model=self.model, - datamodule=ParametricUMAPDatamodule(dataset, self.batch_size, self.num_workers) + 
datamodule=ParametricUMAPDatamodule( + dataset, self.batch_size, self.num_workers + ), ) @torch.no_grad() diff --git a/src/vak/models/registry.py b/src/vak/models/registry.py index 5d626e627..a3d4cce88 100644 --- a/src/vak/models/registry.py +++ b/src/vak/models/registry.py @@ -46,9 +46,7 @@ def register_model(model_class: Type) -> None: with the existing :func:`vak.decorator.model`, that creates a model class from a model definition. """ - model_family_classes = list( - MODEL_FAMILY_REGISTRY.values() - ) + model_family_classes = list(MODEL_FAMILY_REGISTRY.values()) model_parent_class = inspect.getmro(model_class)[1] if model_parent_class not in model_family_classes: raise TypeError( @@ -72,7 +70,9 @@ def register_model(model_class: Type) -> None: f"Classes in the model family registry:\n{MODELS_BY_FAMILY_REGISTRY}" ) - MODELS_BY_FAMILY_REGISTRY[model_parent_class_name][model_name] = model_class + MODELS_BY_FAMILY_REGISTRY[model_parent_class_name][ + model_name + ] = model_class # need to return class after we register it or we replace it with None # when this function is used as a decorator return model_class @@ -80,19 +80,19 @@ def register_model(model_class: Type) -> None: def __getattr__(name: str) -> Any: """Module-level __getattr__ function that we use to dynamically determine models.""" - if name == 'MODEL_FAMILY_FROM_NAME': + if name == "MODEL_FAMILY_FROM_NAME": return { model_name: family_name for family_name, family_dict in MODELS_BY_FAMILY_REGISTRY.items() for model_name, model_class in family_dict.items() } - elif name == 'MODEL_CLASS_BY_NAME': + elif name == "MODEL_CLASS_BY_NAME": return { model_name: model_class for family_name, family_dict in MODELS_BY_FAMILY_REGISTRY.items() for model_name, model_class in family_dict.items() } - elif name == 'MODEL_NAMES': + elif name == "MODEL_NAMES": return list( { model_name: model_class diff --git a/src/vak/models/teenytweetynet.py b/src/vak/models/teenytweetynet.py index a6299fcc9..bdf839c57 100644 --- 
a/src/vak/models/teenytweetynet.py +++ b/src/vak/models/teenytweetynet.py @@ -2,24 +2,22 @@ """ import torch -from .. import metrics -from .. import nets - -from .frame_classification_model import FrameClassificationModel +from .. import metrics, nets from .decorator import model +from .frame_classification_model import FrameClassificationModel @model(family=FrameClassificationModel) class TeenyTweetyNet: """lightweight version of ``vak.models.TweetyNet`` used by ``vak`` unit tests""" + network = nets.TeenyTweetyNet loss = torch.nn.CrossEntropyLoss optimizer = torch.optim.Adam - metrics = {'acc': metrics.Accuracy, - 'levenshtein': metrics.Levenshtein, - 'segment_error_rate': metrics.SegmentErrorRate, - 'loss': torch.nn.CrossEntropyLoss} - default_config = { - 'optimizer': - {'lr': 0.003} + metrics = { + "acc": metrics.Accuracy, + "levenshtein": metrics.Levenshtein, + "segment_error_rate": metrics.SegmentErrorRate, + "loss": torch.nn.CrossEntropyLoss, } + default_config = {"optimizer": {"lr": 0.003}} diff --git a/src/vak/models/tweetynet.py b/src/vak/models/tweetynet.py index 1cd4aa4e7..466492011 100644 --- a/src/vak/models/tweetynet.py +++ b/src/vak/models/tweetynet.py @@ -10,12 +10,9 @@ import torch -from .. import ( - metrics, - nets -) -from .frame_classification_model import FrameClassificationModel +from .. 
import metrics, nets from .decorator import model +from .frame_classification_model import FrameClassificationModel @model(family=FrameClassificationModel) @@ -56,14 +53,14 @@ class TweetyNet: Paper: https://elifesciences.org/articles/63853 Code: https://github.com/yardencsGitHub/tweetynet """ + network = nets.TweetyNet loss = torch.nn.CrossEntropyLoss optimizer = torch.optim.Adam - metrics = {'acc': metrics.Accuracy, - 'levenshtein': metrics.Levenshtein, - 'segment_error_rate': metrics.SegmentErrorRate, - 'loss': torch.nn.CrossEntropyLoss} - default_config = { - 'optimizer': - {'lr': 0.003} + metrics = { + "acc": metrics.Accuracy, + "levenshtein": metrics.Levenshtein, + "segment_error_rate": metrics.SegmentErrorRate, + "loss": torch.nn.CrossEntropyLoss, } + default_config = {"optimizer": {"lr": 0.003}} diff --git a/src/vak/nets/__init__.py b/src/vak/nets/__init__.py index ab0202af2..8602ec7dc 100644 --- a/src/vak/nets/__init__.py +++ b/src/vak/nets/__init__.py @@ -1,23 +1,16 @@ -from . import ( - conv_encoder, - ed_tcn, - teenytweetynet, - tweetynet, -) - +from . import conv_encoder, ed_tcn, teenytweetynet, tweetynet from .conv_encoder import ConvEncoder from .ed_tcn import ED_TCN from .teenytweetynet import TeenyTweetyNet from .tweetynet import TweetyNet - __all__ = [ - 'conv_encoder', - 'ConvEncoder', - 'ed_tcn', - 'ED_TCN', - 'teenytweetynet', - 'TeenyTweetyNet', - 'tweetynet', - 'TweetyNet', + "conv_encoder", + "ConvEncoder", + "ed_tcn", + "ED_TCN", + "teenytweetynet", + "TeenyTweetyNet", + "tweetynet", + "TweetyNet", ] diff --git a/src/vak/nets/conv_encoder.py b/src/vak/nets/conv_encoder.py index 16527b0cd..6fd828f2f 100644 --- a/src/vak/nets/conv_encoder.py +++ b/src/vak/nets/conv_encoder.py @@ -8,15 +8,18 @@ class ConvEncoder(nn.Module): """Convolutional encoder, used by Parametric UMAP model. 
""" - def __init__(self, - input_shape: tuple[int], - conv1_filters: int = 64, - conv2_filters: int = 128, - conv_kernel_size: int = 3, - conv_stride: int = 2, - conv_padding: int = 1, - n_features_linear: int = 512, - n_components: int = 2): + + def __init__( + self, + input_shape: tuple[int], + conv1_filters: int = 64, + conv2_filters: int = 128, + conv_kernel_size: int = 3, + conv_stride: int = 2, + conv_padding: int = 1, + n_features_linear: int = 512, + n_components: int = 2, + ): """Initialize a ConvEncoder instance. Parameters @@ -51,14 +54,20 @@ def __init__(self, self.conv = nn.Sequential( nn.Conv2d( - in_channels=self.num_input_channels, out_channels=conv1_filters, - kernel_size=conv_kernel_size, stride=conv_stride, padding=conv_padding, + in_channels=self.num_input_channels, + out_channels=conv1_filters, + kernel_size=conv_kernel_size, + stride=conv_stride, + padding=conv_padding, ), nn.Conv2d( - in_channels=conv1_filters, out_channels=conv2_filters, - kernel_size=conv_kernel_size, stride=conv_stride, padding=conv_padding, + in_channels=conv1_filters, + out_channels=conv2_filters, + kernel_size=conv_kernel_size, + stride=conv_stride, + padding=conv_padding, ), - nn.Flatten() + nn.Flatten(), ) mock_input = torch.rand((1, *input_shape)) mock_conv_out = self.conv(mock_input) @@ -69,7 +78,7 @@ def __init__(self, nn.ReLU(), nn.Linear(n_features_linear, n_features_linear), nn.ReLU(), - nn.Linear(n_features_linear, n_components) + nn.Linear(n_features_linear, n_components), ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/src/vak/nets/ed_tcn.py b/src/vak/nets/ed_tcn.py index 3b6a2426a..dc8d7f9d1 100644 --- a/src/vak/nets/ed_tcn.py +++ b/src/vak/nets/ed_tcn.py @@ -1,4 +1,5 @@ import torch + from ..nn.modules import Conv2dTF, NormReLU @@ -18,81 +19,112 @@ class ED_TCN(torch.nn.Module): Temporal convolutional networks for action segmentation and detection. In proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 
156-165). """ - def __init__(self, - num_classes, - num_input_channels=1, - num_freqbins=256, - padding='SAME', - conv2d_1_filters=32, - conv2d_1_kernel_size=(5, 5), - conv2d_2_filters=64, - conv2d_2_kernel_size=(5, 5), - pool1_size=(8, 1), - pool1_stride=(8, 1), - pool2_size=(8, 1), - pool2_stride=(8, 1), - conv1d_1_filters=64, - conv1d_2_filters=96, - conv1d_kernel_size=25, - ): + + def __init__( + self, + num_classes, + num_input_channels=1, + num_freqbins=256, + padding="SAME", + conv2d_1_filters=32, + conv2d_1_kernel_size=(5, 5), + conv2d_2_filters=64, + conv2d_2_kernel_size=(5, 5), + pool1_size=(8, 1), + pool1_stride=(8, 1), + pool2_size=(8, 1), + pool2_stride=(8, 1), + conv1d_1_filters=64, + conv1d_2_filters=96, + conv1d_kernel_size=25, + ): super().__init__() self.num_classes = num_classes self.num_input_channels = num_input_channels self.num_freqbins = num_freqbins self.cnn = torch.nn.Sequential( - Conv2dTF(in_channels=self.num_input_channels, - out_channels=conv2d_1_filters, - kernel_size=conv2d_1_kernel_size, - padding=padding - ), + Conv2dTF( + in_channels=self.num_input_channels, + out_channels=conv2d_1_filters, + kernel_size=conv2d_1_kernel_size, + padding=padding, + ), torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=pool1_size, - stride=pool1_stride), - Conv2dTF(in_channels=conv2d_1_filters, - out_channels=conv2d_2_filters, - kernel_size=conv2d_2_kernel_size, - padding=padding, - ), + torch.nn.MaxPool2d(kernel_size=pool1_size, stride=pool1_stride), + Conv2dTF( + in_channels=conv2d_1_filters, + out_channels=conv2d_2_filters, + kernel_size=conv2d_2_kernel_size, + padding=padding, + ), torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(kernel_size=pool2_size, - stride=pool2_stride), + torch.nn.MaxPool2d(kernel_size=pool2_size, stride=pool2_stride), ) # determine number of features in output after stacking channels # we use the same number of features for hidden states # note self.num_hidden is also used to reshape output of cnn in 
self.forward method # determine number of features in output after stacking channels - N_DUMMY_TIMEBINS = 256 # some not-small number. This dimension doesn't matter here - batch_shape = (1, self.num_input_channels, self.num_freqbins, N_DUMMY_TIMEBINS) + N_DUMMY_TIMEBINS = ( + 256 # some not-small number. This dimension doesn't matter here + ) + batch_shape = ( + 1, + self.num_input_channels, + self.num_freqbins, + N_DUMMY_TIMEBINS, + ) tmp_tensor = torch.rand(batch_shape) tmp_out = self.cnn(tmp_tensor) channels_out, freqbins_out = tmp_out.shape[1], tmp_out.shape[2] self.n_cnn_features_out = channels_out * freqbins_out self.encoder = torch.nn.Sequential( - torch.nn.Conv1d(self.n_cnn_features_out, conv1d_1_filters, conv1d_kernel_size, padding='same'), + torch.nn.Conv1d( + self.n_cnn_features_out, + conv1d_1_filters, + conv1d_kernel_size, + padding="same", + ), torch.nn.Dropout1d(p=0.3), NormReLU(), torch.nn.MaxPool1d(kernel_size=2), - torch.nn.Conv1d(conv1d_1_filters, conv1d_2_filters, conv1d_kernel_size, padding='same'), + torch.nn.Conv1d( + conv1d_1_filters, + conv1d_2_filters, + conv1d_kernel_size, + padding="same", + ), torch.nn.Dropout1d(0.3), NormReLU(), - torch.nn.MaxPool1d(kernel_size=2) + torch.nn.MaxPool1d(kernel_size=2), ) self.decoder = torch.nn.Sequential( torch.nn.Upsample(scale_factor=2), - torch.nn.Conv1d(conv1d_2_filters, conv1d_2_filters, conv1d_kernel_size, padding='same'), + torch.nn.Conv1d( + conv1d_2_filters, + conv1d_2_filters, + conv1d_kernel_size, + padding="same", + ), torch.nn.Dropout1d(p=0.3), NormReLU(), torch.nn.Upsample(scale_factor=2), - torch.nn.Conv1d(conv1d_2_filters, conv1d_1_filters, conv1d_kernel_size, padding='same'), + torch.nn.Conv1d( + conv1d_2_filters, + conv1d_1_filters, + conv1d_kernel_size, + padding="same", + ), torch.nn.Dropout1d(0.3), NormReLU(), ) - self.fc = torch.nn.Linear(in_features=conv1d_1_filters, out_features=self.num_classes) + self.fc = torch.nn.Linear( + in_features=conv1d_1_filters, 
out_features=self.num_classes + ) def forward(self, x): x = self.cnn(x) @@ -100,7 +132,11 @@ def forward(self, x): x = x.view(x.shape[0], self.n_cnn_features_out, -1) x = self.encoder(x) x = self.decoder(x) - x = x.permute(0, 2, 1) # so that we can project features down on to number of classes + x = x.permute( + 0, 2, 1 + ) # so that we can project features down on to number of classes x = self.fc(x) - x = x.permute(0, 2, 1) # switch back to (batch, classes, time) for loss function + x = x.permute( + 0, 2, 1 + ) # switch back to (batch, classes, time) for loss function return x diff --git a/src/vak/nets/teenytweetynet.py b/src/vak/nets/teenytweetynet.py index 4e91e3f4f..dafd694b9 100644 --- a/src/vak/nets/teenytweetynet.py +++ b/src/vak/nets/teenytweetynet.py @@ -10,6 +10,7 @@ class TeenyTweetyNet(nn.Module): ----- This is the network used by ``vak.models.TeenyTweetyNetModel``. """ + def __init__( self, num_classes, @@ -85,8 +86,15 @@ def __init__( # determine number of features in output after stacking channels # we use the same number of features for hidden states # note self.num_hidden is also used to reshape output of cnn in self.forward method - N_DUMMY_TIMEBINS = 256 # some not-small number. This dimension doesn't matter here - batch_shape = (1, self.num_input_channels, self.num_freqbins, N_DUMMY_TIMEBINS) + N_DUMMY_TIMEBINS = ( + 256 # some not-small number. 
This dimension doesn't matter here + ) + batch_shape = ( + 1, + self.num_input_channels, + self.num_freqbins, + N_DUMMY_TIMEBINS, + ) tmp_tensor = torch.rand(batch_shape) tmp_out = self.cnn(tmp_tensor) channels_out, freqbins_out = tmp_out.shape[1], tmp_out.shape[2] diff --git a/src/vak/nets/tweetynet.py b/src/vak/nets/tweetynet.py index b6d62f19a..9a5136a02 100644 --- a/src/vak/nets/tweetynet.py +++ b/src/vak/nets/tweetynet.py @@ -3,6 +3,7 @@ import torch from torch import nn + from ..nn.modules import Conv2dTF @@ -47,24 +48,26 @@ class TweetyNet(nn.Module): ----- This is the network used by ``vak.models.TweetyNetModel``. """ - def __init__(self, - num_classes, - num_input_channels=1, - num_freqbins=256, - padding='SAME', - conv1_filters=32, - conv1_kernel_size=(5, 5), - conv2_filters=64, - conv2_kernel_size=(5, 5), - pool1_size=(8, 1), - pool1_stride=(8, 1), - pool2_size=(8, 1), - pool2_stride=(8, 1), - hidden_size=None, - rnn_dropout=0., - num_layers=1, - bidirectional=True, - ): + + def __init__( + self, + num_classes, + num_input_channels=1, + num_freqbins=256, + padding="SAME", + conv1_filters=32, + conv1_kernel_size=(5, 5), + conv2_filters=64, + conv2_kernel_size=(5, 5), + pool1_size=(8, 1), + pool1_stride=(8, 1), + pool2_size=(8, 1), + pool2_stride=(8, 1), + hidden_size=None, + rnn_dropout=0.0, + num_layers=1, + bidirectional=True, + ): """initialize TweetyNet model Parameters @@ -112,29 +115,36 @@ def __init__(self, self.num_freqbins = num_freqbins self.cnn = nn.Sequential( - Conv2dTF(in_channels=self.num_input_channels, - out_channels=conv1_filters, - kernel_size=conv1_kernel_size, - padding=padding - ), + Conv2dTF( + in_channels=self.num_input_channels, + out_channels=conv1_filters, + kernel_size=conv1_kernel_size, + padding=padding, + ), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=pool1_size, - stride=pool1_stride), - Conv2dTF(in_channels=conv1_filters, - out_channels=conv2_filters, - kernel_size=conv2_kernel_size, - padding=padding, - ), + 
nn.MaxPool2d(kernel_size=pool1_size, stride=pool1_stride), + Conv2dTF( + in_channels=conv1_filters, + out_channels=conv2_filters, + kernel_size=conv2_kernel_size, + padding=padding, + ), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=pool2_size, - stride=pool2_stride), + nn.MaxPool2d(kernel_size=pool2_size, stride=pool2_stride), ) # determine number of features in output after stacking channels # we use the same number of features for hidden states # note self.num_hidden is also used to reshape output of cnn in self.forward method - N_DUMMY_TIMEBINS = 256 # some not-small number. This dimension doesn't matter here - batch_shape = (1, self.num_input_channels, self.num_freqbins, N_DUMMY_TIMEBINS) + N_DUMMY_TIMEBINS = ( + 256 # some not-small number. This dimension doesn't matter here + ) + batch_shape = ( + 1, + self.num_input_channels, + self.num_freqbins, + N_DUMMY_TIMEBINS, + ) tmp_tensor = torch.rand(batch_shape) tmp_out = self.cnn(tmp_tensor) channels_out, freqbins_out = tmp_out.shape[1], tmp_out.shape[2] @@ -145,15 +155,19 @@ def __init__(self, else: self.hidden_size = hidden_size - self.rnn = nn.LSTM(input_size=self.rnn_input_size, - hidden_size=self.hidden_size, - num_layers=num_layers, - dropout=rnn_dropout, - bidirectional=bidirectional) + self.rnn = nn.LSTM( + input_size=self.rnn_input_size, + hidden_size=self.hidden_size, + num_layers=num_layers, + dropout=rnn_dropout, + bidirectional=bidirectional, + ) # for self.fc, in_features = hidden_size * 2 because LSTM is bidirectional # so we get hidden forward + hidden backward as output - self.fc = nn.Linear(in_features=self.hidden_size * 2, out_features=num_classes) + self.fc = nn.Linear( + in_features=self.hidden_size * 2, out_features=num_classes + ) def forward(self, x): features = self.cnn(x) diff --git a/src/vak/nn/__init__.py b/src/vak/nn/__init__.py index 00f6ea247..b3d46b8d0 100644 --- a/src/vak/nn/__init__.py +++ b/src/vak/nn/__init__.py @@ -1,3 +1,3 @@ -from .loss import * -from .modules import * 
-from . import functional +from . import functional # noqa: F401, F403 +from .loss import * # noqa: F401, F403 +from .modules import * # noqa: F401, F403 diff --git a/src/vak/nn/functional.py b/src/vak/nn/functional.py index 57db42b1f..3cf9b8567 100644 --- a/src/vak/nn/functional.py +++ b/src/vak/nn/functional.py @@ -2,7 +2,6 @@ import torch - __all__ = ["one_hot"] @@ -40,12 +39,16 @@ def one_hot( """ if not isinstance(labels, torch.Tensor): raise TypeError( - "Input labels type is not a torch.Tensor. Got {}".format(type(labels)) + "Input labels type is not a torch.Tensor. Got {}".format( + type(labels) + ) ) if not labels.dtype == torch.int64: raise ValueError( - "labels must be of the same dtype torch.int64. Got: {}".format(labels.dtype) + "labels must be of the same dtype torch.int64. Got: {}".format( + labels.dtype + ) ) if num_classes < 1: diff --git a/src/vak/nn/loss/__init__.py b/src/vak/nn/loss/__init__.py index 5e6951100..0f659b873 100644 --- a/src/vak/nn/loss/__init__.py +++ b/src/vak/nn/loss/__init__.py @@ -1,2 +1,10 @@ -from .dice import * -from .umap import umap_loss, UmapLoss +from .dice import DiceLoss, dice_loss +from .umap import UmapLoss, umap_loss + + +__all__ = [ + "DiceLoss", + "dice_loss", + "UmapLoss", + "umap_loss", +] diff --git a/src/vak/nn/loss/dice.py b/src/vak/nn/loss/dice.py index 3e35b0386..db334da4f 100644 --- a/src/vak/nn/loss/dice.py +++ b/src/vak/nn/loss/dice.py @@ -44,7 +44,9 @@ def dice_loss( >>> output.backward() """ if not isinstance(input, torch.Tensor): - raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(input))) + raise TypeError( + "Input type is not a torch.Tensor. 
Got {}".format(type(input)) + ) if not len(input.shape) == 3: raise ValueError( @@ -68,7 +70,10 @@ def dice_loss( # create the labels one hot tensor target_one_hot: torch.Tensor = vakF.one_hot( - target, num_classes=input.shape[1], device=input.device, dtype=input.dtype + target, + num_classes=input.shape[1], + device=input.device, + dtype=input.dtype, ) # compute the actual dice score @@ -116,5 +121,7 @@ def __init__(self, eps: float = 1e-8) -> None: super(DiceLoss, self).__init__() self.eps: float = eps - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, input: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: return dice_loss(input, target, self.eps) diff --git a/src/vak/nn/loss/umap.py b/src/vak/nn/loss/umap.py index 7ceaba829..794a4d26e 100644 --- a/src/vak/nn/loss/umap.py +++ b/src/vak/nn/loss/umap.py @@ -4,14 +4,17 @@ import warnings import torch -from torch.nn.functional import mse_loss +# isort: off # Ignore warnings from Numba deprecation: # https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit # Numba is required by UMAP. 
from numba.core.errors import NumbaDeprecationWarning -warnings.simplefilter('ignore', category=NumbaDeprecationWarning) -from umap.umap_ import find_ab_params +from torch.nn.functional import mse_loss + +warnings.simplefilter("ignore", category=NumbaDeprecationWarning) +from umap.umap_ import find_ab_params # noqa : E402 +# isort: on def convert_distance_to_probability(distances, a=1.0, b=1.0): @@ -30,7 +33,10 @@ def convert_distance_to_probability(distances, a=1.0, b=1.0): def compute_cross_entropy( - probabilities_graph, probabilities_distance, EPS=1e-4, repulsion_strength=1.0 + probabilities_graph, + probabilities_distance, + EPS=1e-4, + repulsion_strength=1.0, ): """Computes cross entropy as used for UMAP cost function""" # cross entropy @@ -38,7 +44,12 @@ def compute_cross_entropy( probabilities_distance ) repulsion_term = ( - -(1.0 - probabilities_graph) * (torch.nn.functional.logsigmoid(probabilities_distance) - probabilities_distance) * repulsion_strength + -(1.0 - probabilities_graph) + * ( + torch.nn.functional.logsigmoid(probabilities_distance) + - probabilities_distance + ) + * repulsion_strength ) # balance the expected losses between attraction and repulsion @@ -46,8 +57,13 @@ def compute_cross_entropy( return attraction_term, repulsion_term, CE -def umap_loss(embedding_to: torch.Tensor, embedding_from: torch.Tensor, - a, b, negative_sample_rate: int = 5): +def umap_loss( + embedding_to: torch.Tensor, + embedding_from: torch.Tensor, + a, + b, + negative_sample_rate: int = 5, +): """UMAP loss function Converts distances to probabilities, @@ -57,11 +73,14 @@ def umap_loss(embedding_to: torch.Tensor, embedding_from: torch.Tensor, embedding_neg_to = embedding_to.repeat(negative_sample_rate, 1) repeat_neg = embedding_from.repeat(negative_sample_rate, 1) embedding_neg_from = repeat_neg[torch.randperm(repeat_neg.shape[0])] - distance_embedding = torch.cat(( - (embedding_to - embedding_from).norm(dim=1), - (embedding_neg_to - embedding_neg_from).norm(dim=1) 
- # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` - ), dim=0).to(embedding_to.device) + distance_embedding = torch.cat( + ( + (embedding_to - embedding_from).norm(dim=1), + (embedding_neg_to - embedding_neg_from).norm(dim=1) + # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` + ), + dim=0, + ).to(embedding_to.device) # convert probabilities to distances probabilities_distance = convert_distance_to_probability( @@ -70,8 +89,12 @@ def umap_loss(embedding_to: torch.Tensor, embedding_from: torch.Tensor, # set true probabilities based on negative sampling batch_size = embedding_to.shape[0] probabilities_graph = torch.cat( - (torch.ones(batch_size), torch.zeros(batch_size * negative_sample_rate)), dim=0, - # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` + ( + torch.ones(batch_size), + torch.zeros(batch_size * negative_sample_rate), + ), + dim=0, + # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` ).to(embedding_to.device) # compute cross entropy @@ -85,22 +108,34 @@ def umap_loss(embedding_to: torch.Tensor, embedding_from: torch.Tensor, class UmapLoss(torch.nn.Module): """""" - def __init__(self, - spread: float = 1.0, - min_dist: float = 0.1, - negative_sample_rate: int = 5, - beta: float = 1.0, - ): + + def __init__( + self, + spread: float = 1.0, + min_dist: float = 0.1, + negative_sample_rate: int = 5, + beta: float = 1.0, + ): super().__init__() self.min_dist = min_dist self.a, self.b = find_ab_params(spread, min_dist) self.negative_sample_rate = negative_sample_rate self.beta = beta - - def forward(self, embedding_to: torch.Tensor, embedding_from: torch.Tensor, - reconstruction: torch.Tensor | None = None, before_encoding: torch.Tensor | None = None): - loss_umap = umap_loss(embedding_to, embedding_from, self.a, self.b, self.negative_sample_rate) + def forward( + self, + embedding_to: 
torch.Tensor, + embedding_from: torch.Tensor, + reconstruction: torch.Tensor | None = None, + before_encoding: torch.Tensor | None = None, + ): + loss_umap = umap_loss( + embedding_to, + embedding_from, + self.a, + self.b, + self.negative_sample_rate, + ) if reconstruction is not None: loss_reconstruction = mse_loss(reconstruction, before_encoding) loss = loss_umap + self.beta * loss_reconstruction diff --git a/src/vak/nn/modules/__init__.py b/src/vak/nn/modules/__init__.py index 9120afbce..b6cf86f99 100644 --- a/src/vak/nn/modules/__init__.py +++ b/src/vak/nn/modules/__init__.py @@ -2,6 +2,6 @@ from .conv import Conv2dTF __all__ = [ - 'Conv2dTF', - 'NormReLU', + "Conv2dTF", + "NormReLU", ] diff --git a/src/vak/nn/modules/activation.py b/src/vak/nn/modules/activation.py index 77c7d94f5..57a884496 100644 --- a/src/vak/nn/modules/activation.py +++ b/src/vak/nn/modules/activation.py @@ -12,6 +12,7 @@ class NormReLU(torch.nn.Module): Temporal convolutional networks for action segmentation and detection. In proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 156-165). """ + def __init__(self): super().__init__() self.relu = torch.nn.ReLU(inplace=True) diff --git a/src/vak/nn/modules/conv.py b/src/vak/nn/modules/conv.py index 3b935da9e..778e5abf3 100644 --- a/src/vak/nn/modules/conv.py +++ b/src/vak/nn/modules/conv.py @@ -1,6 +1,8 @@ """Modules that perform neural network convolutions.""" import torch from torch.nn import functional as F + + # NOTE: added 2023-03-04 # in this class, we detect when one extra pixel should be added on the bottom or right # and specifically pad those, see line 75, ``if rows_odd or cols_odd:``. @@ -17,16 +19,19 @@ class Conv2dTF(torch.nn.Conv2d): Note there are issues with SAME convolution as performed by Tensorflow. See https://gist.github.com/Yangqing/47772de7eb3d5dbbff50ffb0d7a98964. 
""" - PADDING_METHODS = ('VALID', 'SAME') + + PADDING_METHODS = ("VALID", "SAME") def __init__(self, *args, **kwargs): # remove 'padding' from ``kwargs`` to avoid bug in ``torch`` => 1.7.2 # see https://github.com/yardencsGitHub/tweetynet/issues/166 - kwargs_super = {k: v for k, v in kwargs.items() if k != 'padding'} + kwargs_super = {k: v for k, v in kwargs.items() if k != "padding"} super(Conv2dTF, self).__init__(*args, **kwargs_super) padding = kwargs.get("padding", "SAME") if not isinstance(padding, str): - raise TypeError(f"value for 'padding' argument should be a string, one of: {self.PADDING_METHODS}") + raise TypeError( + f"value for 'padding' argument should be a string, one of: {self.PADDING_METHODS}" + ) padding = padding.upper() if padding not in self.PADDING_METHODS: raise ValueError( @@ -40,7 +45,10 @@ def _compute_padding(self, input, dim): effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1 out_size = (input_size + self.stride[dim] - 1) // self.stride[dim] total_padding = max( - 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size + 0, + (out_size - 1) * self.stride[dim] + + effective_filter_size + - input_size, ) additional_padding = int(total_padding % 2 != 0) diff --git a/src/vak/plot/__init__.py b/src/vak/plot/__init__.py index 91bc49300..29d094b09 100644 --- a/src/vak/plot/__init__.py +++ b/src/vak/plot/__init__.py @@ -3,6 +3,7 @@ from .spect import spect, spect_annot __all__ = [ + "annotation", "learncurve", "plot_labels", "plot_segments", diff --git a/src/vak/plot/annot.py b/src/vak/plot/annot.py index f0df64f8e..fca7294d0 100644 --- a/src/vak/plot/annot.py +++ b/src/vak/plot/annot.py @@ -1,7 +1,7 @@ """functions for plotting annotations for vocalizations""" import matplotlib.pyplot as plt -from matplotlib.collections import LineCollection import numpy as np +from matplotlib.collections import LineCollection def plot_segments(onsets, offsets, y=0.5, ax=None, line_kwargs=None): @@ -139,5 +139,9 @@ def 
annotation( segment_centers = np.array(segment_centers) plot_labels( - labels=labels, t=segment_centers, y=y_labels, ax=ax, text_kwargs=text_kwargs + labels=labels, + t=segment_centers, + y=y_labels, + ax=ax, + text_kwargs=text_kwargs, ) diff --git a/src/vak/plot/learncurve.py b/src/vak/plot/learncurve.py index 530f96b80..c78cf0792 100644 --- a/src/vak/plot/learncurve.py +++ b/src/vak/plot/learncurve.py @@ -1,16 +1,18 @@ """functions to plot learning curve results""" -from glob import glob import os import pickle from configparser import ConfigParser +from glob import glob import joblib -import numpy as np import matplotlib.pyplot as plt +import numpy as np def get_all_results_list(root, these_dirs_branch, config_file): - these_dirs = [(root + this_dir[0], this_dir[1]) for this_dir in these_dirs_branch] + these_dirs = [ + (root + this_dir[0], this_dir[1]) for this_dir in these_dirs_branch + ] config = ConfigParser() all_results_list = [] @@ -24,13 +26,16 @@ def get_all_results_list(root, these_dirs_branch, config_file): results_dict["time_steps"] = config["NETWORK"]["time_steps"] results_dict["num_hidden"] = config["NETWORK"]["num_hidden"] results_dict["train_set_durs"] = [ - int(element) for element in config["TRAIN"]["train_set_durs"].split(",") + int(element) + for element in config["TRAIN"]["train_set_durs"].split(",") ] with open(glob("summary*/train_err")[0], "rb") as f: results_dict["train_err"] = pickle.load(f) with open(glob("summary*/test_err")[0], "rb") as f: results_dict["test_err"] = pickle.load(f) - pe = joblib.load(glob("summary*/y_preds_and_err_for_train_and_test")[0]) + pe = joblib.load( + glob("summary*/y_preds_and_err_for_train_and_test")[0] + ) results_dict["train_syl_err_rate"] = pe["train_syl_err_rate"] results_dict["test_syl_err_rate"] = pe["test_syl_err_rate"] results_dict["bird_ID"] = bird_ID @@ -78,7 +83,9 @@ def frame_error_rate(all_results_list): plt.legend(fontsize=20) plt.xticks(el["train_set_durs"]) plt.tick_params(axis="both", 
which="major", labelsize=20, rotation=45) - plt.title("Frame error rate as a function of training set size", fontsize=40) + plt.title( + "Frame error rate as a function of training set size", fontsize=40 + ) plt.ylabel("Frame error rate\nas measured on test set", fontsize=32) plt.xlabel("Training set size: duration in s", fontsize=32) plt.tight_layout() @@ -115,7 +122,9 @@ def syllable_error_rate(all_results_list): ) plt.legend(fontsize=20, loc="upper right") - plt.title("Syllable error rate as a function of training set size", fontsize=40) + plt.title( + "Syllable error rate as a function of training set size", fontsize=40 + ) plt.xticks(el["train_set_durs"]) plt.tick_params(axis="both", which="major", labelsize=20, rotation=45) plt.ylabel("Syllable error rate\nas measured on test set", fontsize=32) diff --git a/src/vak/plot/spect.py b/src/vak/plot/spect.py index 0b57078d1..286e357dd 100644 --- a/src/vak/plot/spect.py +++ b/src/vak/plot/spect.py @@ -112,7 +112,12 @@ def spect_annot( spect(s, t, f, tlim, flim, ax=spect_ax, imshow_kwargs=imshow_kwargs) annotation( - annot, t, tlim, ax=annot_ax, line_kwargs=line_kwargs, text_kwargs=text_kwargs + annot, + t, + tlim, + ax=annot_ax, + line_kwargs=line_kwargs, + text_kwargs=text_kwargs, ) return fig, spect_ax, annot_ax diff --git a/src/vak/predict/__init__.py b/src/vak/predict/__init__.py index c83e04d09..f9bbd3aec 100644 --- a/src/vak/predict/__init__.py +++ b/src/vak/predict/__init__.py @@ -1,5 +1,9 @@ -from .predict import * +from .predict import predict +from . 
import frame_classification, parametric_umap + __all__ = [ + "frame_classification", + "parametric_umap", "predict", ] diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index 5c4a7f63c..99b5bf99d 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -8,24 +8,15 @@ import crowsetta import joblib -import pytorch_lightning as lightning import numpy as np -from tqdm import tqdm +import pytorch_lightning as lightning import torch.utils.data +from tqdm import tqdm -from .. import ( - datasets, - models, - transforms -) -from ..common import ( - constants, - files, - validators -) -from ..datasets.frame_classification import FramesDataset +from .. import datasets, models, transforms +from ..common import constants, files, validators from ..common.device import get_default as get_default_device - +from ..datasets.frame_classification import FramesDataset logger = logging.getLogger(__name__) @@ -119,8 +110,8 @@ def predict_with_frame_classification_model( will be `gy6or6_032312_081416.tweetynet.output.npz`. 
""" for path, path_name in zip( - (checkpoint_path, labelmap_path, spect_scaler_path), - ('checkpoint_path', 'labelmap_path', 'spect_scaler_path'), + (checkpoint_path, labelmap_path, spect_scaler_path), + ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: if not validators.is_a_file(path): @@ -152,33 +143,35 @@ def predict_with_frame_classification_model( logger.info(f"loading SpectScaler from path: {spect_scaler_path}") spect_standardizer = joblib.load(spect_scaler_path) else: - logger.info(f"Not loading SpectScaler, no path was specified") + logger.info("Not loading SpectScaler, no path was specified") spect_standardizer = None if transform_params is None: transform_params = {} - transform_params.update({'spect_standardizer': spect_standardizer}) + transform_params.update({"spect_standardizer": spect_standardizer}) item_transform = transforms.defaults.get_default_transform( - model_name, - "predict", - transform_params + model_name, "predict", transform_params ) logger.info(f"loading labelmap from path: {labelmap_path}") with labelmap_path.open("r") as f: labelmap = json.load(f) - metadata = datasets.frame_classification.Metadata.from_dataset_path(dataset_path) + metadata = datasets.frame_classification.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename - logger.info(f"loading dataset to predict from csv path: {dataset_csv_path}") + logger.info( + f"loading dataset to predict from csv path: {dataset_csv_path}" + ) if dataset_params is None: dataset_params = {} pred_dataset = FramesDataset.from_dataset_path( dataset_path=dataset_path, split="predict", item_transform=item_transform, - **dataset_params + **dataset_params, ) pred_loader = torch.utils.data.DataLoader( @@ -191,11 +184,17 @@ def predict_with_frame_classification_model( # ---------------- set up to convert predictions to annotation files ----------------------------------------------- if annot_csv_filename is 
None: - annot_csv_filename = pathlib.Path(dataset_path).stem + constants.ANNOT_CSV_SUFFIX + annot_csv_filename = ( + pathlib.Path(dataset_path).stem + constants.ANNOT_CSV_SUFFIX + ) annot_csv_path = pathlib.Path(output_dir).joinpath(annot_csv_filename) logger.info(f"will save annotations in .csv file: {annot_csv_path}") - metadata = datasets.frame_classification.metadata.Metadata.from_dataset_path(dataset_path) + metadata = ( + datasets.frame_classification.metadata.Metadata.from_dataset_path( + dataset_path + ) + ) frame_dur = metadata.frame_dur logger.info( f"Duration of a frame in dataset, in seconds: {frame_dur}", @@ -207,7 +206,9 @@ def predict_with_frame_classification_model( # throw out the window dimension; just want to tell network (channels, height, width) shape if len(input_shape) == 4: input_shape = input_shape[1:] - logger.info(f"Shape of input to networks used for predictions: {input_shape}") + logger.info( + f"Shape of input to networks used for predictions: {input_shape}" + ) logger.info(f"instantiating model from config:/n{model_name}") @@ -220,16 +221,16 @@ def predict_with_frame_classification_model( ) # ---------------- do the actual predicting -------------------------------------------------------------------- - logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}") + logger.info( + f"loading checkpoint for {model_name} from path: {checkpoint_path}" + ) model.load_state_dict_from_path(checkpoint_path) - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) + trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) logger.info(f"running predict method of {model_name}") @@ -243,10 +244,12 @@ def predict_with_frame_classification_model( # ---------------- converting to annotations 
------------------------------------------------------------------ progress_bar = tqdm(pred_loader) - input_type = metadata.input_type # we use this to get frame_times inside loop - if input_type == 'audio': + input_type = ( + metadata.input_type + ) # we use this to get frame_times inside loop + if input_type == "audio": audio_format = metadata.audio_format - elif input_type == 'spect': + elif input_type == "spect": spect_format = metadata.spect_format annots = [] logger.info("converting predictions to annotations") @@ -266,18 +269,23 @@ def predict_with_frame_classification_model( net_output = net_output[:, padding_mask] net_output = net_output.cpu().numpy() net_output_path = output_dir.joinpath( - pathlib.Path(source_path).stem + f"{model_name}{constants.NET_OUTPUT_SUFFIX}" + pathlib.Path(source_path).stem + + f"{model_name}{constants.NET_OUTPUT_SUFFIX}" ) np.savez(net_output_path, net_output) y_pred = torch.argmax(y_pred, dim=1) # assumes class dimension is 1 y_pred = torch.flatten(y_pred).cpu().numpy()[padding_mask] - if input_type == 'audio': - frames, samplefreq = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](source_path) + if input_type == "audio": + frames, samplefreq = constants.AUDIO_FORMAT_FUNC_MAP[audio_format]( + source_path + ) frame_times = np.arange(frames.shape[-1]) / samplefreq - elif input_type == 'spect': - spect_dict = files.spect.load(dataset_path / source_path, spect_format=spect_format) + elif input_type == "spect": + spect_dict = files.spect.load( + dataset_path / source_path, spect_format=spect_format + ) frame_times = spect_dict[timebins_key] if majority_vote or min_segment_dur: diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py index 4b056b289..df7eba8a0 100644 --- a/src/vak/predict/parametric_umap.py +++ b/src/vak/predict/parametric_umap.py @@ -8,18 +8,10 @@ import pytorch_lightning as lightning import torch.utils.data -from .. 
import ( - datasets, - models, - transforms -) -from ..common import ( - constants, - validators -) -from ..datasets.parametric_umap import ParametricUMAPDataset +from .. import datasets, models, transforms +from ..common import validators from ..common.device import get_default as get_default_device - +from ..datasets.parametric_umap import ParametricUMAPDataset logger = logging.getLogger(__name__) @@ -75,8 +67,8 @@ def predict_with_parametric_umap_model( should be saved. Defaults to current working directory. """ for path, path_name in zip( - (checkpoint_path,), - ('checkpoint_path',), + (checkpoint_path,), + ("checkpoint_path",), ): if path is not None: if not validators.is_a_file(path): @@ -92,7 +84,9 @@ def predict_with_parametric_umap_model( logger.info( f"Loading metadata from dataset path: {dataset_path}", ) - metadata = datasets.frame_classification.Metadata.from_dataset_path(dataset_path) + metadata = datasets.frame_classification.Metadata.from_dataset_path( + dataset_path + ) if output_dir is None: output_dir = pathlib.Path(os.getcwd()) @@ -110,18 +104,18 @@ def predict_with_parametric_umap_model( # ---------------- load data for prediction ------------------------------------------------------------------------ if transform_params is None: transform_params = {} - if 'padding' not in transform_params and model_name == 'ConvEncoderUMAP': + if "padding" not in transform_params and model_name == "ConvEncoderUMAP": padding = models.convencoder_umap.get_default_padding(metadata.shape) - transform_params['padding'] = padding + transform_params["padding"] = padding item_transform = transforms.defaults.get_default_transform( - model_name, - "predict", - transform_params + model_name, "predict", transform_params ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename - logger.info(f"loading dataset to predict from csv path: {dataset_csv_path}") + logger.info( + f"loading dataset to predict from csv path: {dataset_csv_path}" + ) if dataset_params is 
None: dataset_params = {} @@ -129,7 +123,7 @@ def predict_with_parametric_umap_model( dataset_path=dataset_path, split="predict", transform=item_transform, - **dataset_params + **dataset_params, ) pred_loader = torch.utils.data.DataLoader( @@ -140,19 +134,15 @@ def predict_with_parametric_umap_model( num_workers=num_workers, ) - # ---------------- set up to convert predictions to annotation files ----------------------------------------------- - if annot_csv_filename is None: - annot_csv_filename = pathlib.Path(dataset_path).stem + constants.ANNOT_CSV_SUFFIX - annot_csv_path = pathlib.Path(output_dir).joinpath(annot_csv_filename) - logger.info(f"will save annotations in .csv file: {annot_csv_path}") - # ---------------- do the actual predicting + converting to annotations -------------------------------------------- input_shape = pred_dataset.shape # if dataset returns spectrogram reshaped into windows, # throw out the window dimension; just want to tell network (channels, height, width) shape if len(input_shape) == 4: input_shape = input_shape[1:] - logger.info(f"Shape of input to networks used for predictions: {input_shape}") + logger.info( + f"Shape of input to networks used for predictions: {input_shape}" + ) logger.info(f"instantiating model from config:/n{model_name}") @@ -163,24 +153,24 @@ def predict_with_parametric_umap_model( ) # ---------------- do the actual predicting -------------------------------------------------------------------- - logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}") + logger.info( + f"loading checkpoint for {model_name} from path: {checkpoint_path}" + ) model.load_state_dict_from_path(checkpoint_path) - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) + trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) trainer = 
lightning.Trainer(accelerator=accelerator, logger=trainer_logger) logger.info(f"running predict method of {model_name}") - results = trainer.predict(model, pred_loader) - - eval_df = pd.DataFrame(row, index=[0]) - eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv") - logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}") - eval_df.to_csv( - eval_csv_path, index=False - ) # index is False to avoid having "Unnamed: 0" column when loading + results = trainer.predict(model, pred_loader) # noqa : F841 + + # eval_df = pd.DataFrame(row, index=[0]) + # eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv") + # logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}") + # eval_df.to_csv( + # eval_csv_path, index=False + # ) # index is False to avoid having "Unnamed: 0" column when loading diff --git a/src/vak/predict/predict.py b/src/vak/predict/predict.py index 755600aed..29208d0f2 100644 --- a/src/vak/predict/predict.py +++ b/src/vak/predict/predict.py @@ -5,15 +5,10 @@ import os import pathlib -from .frame_classification import predict_with_frame_classification_model -from .. import ( - models -) -from ..common import ( - validators -) +from .. import models +from ..common import validators from ..common.device import get_default as get_default_device - +from .frame_classification import predict_with_frame_classification_model logger = logging.getLogger(__name__) @@ -31,7 +26,7 @@ def predict( spect_scaler_path: str | pathlib.Path | None = None, device: str | None = None, annot_csv_filename: str | None = None, - output_dir: str | pathlib.Path | None = None, + output_dir: str | pathlib.Path | None = None, min_segment_dur: float | None = None, majority_vote: bool = False, save_net_outputs: bool = False, @@ -108,8 +103,8 @@ def predict( will be `gy6or6_032312_081416.tweetynet.output.npz`. 
""" for path, path_name in zip( - (checkpoint_path, labelmap_path, spect_scaler_path), - ('checkpoint_path', 'labelmap_path', 'spect_scaler_path'), + (checkpoint_path, labelmap_path, spect_scaler_path), + ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: if not validators.is_a_file(path): @@ -162,6 +157,4 @@ def predict( save_net_outputs=save_net_outputs, ) else: - raise ValueError( - f"Model family not recognized: {model_family}" - ) + raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/prep/__init__.py b/src/vak/prep/__init__.py index f22419500..832630976 100644 --- a/src/vak/prep/__init__.py +++ b/src/vak/prep/__init__.py @@ -9,14 +9,13 @@ ) from .prep import prep - __all__ = [ - 'audio_dataset', - 'constants', - 'dataset_df_helper', - 'dimensionality_reduction', - 'frame_classification', - 'prep', - 'spectrogram_dataset', - 'unit_dataset', + "audio_dataset", + "constants", + "dataset_df_helper", + "frame_classification", + "parametric_umap", + "prep", + "spectrogram_dataset", + "unit_dataset", ] diff --git a/src/vak/prep/audio_dataset.py b/src/vak/prep/audio_dataset.py index dc8ff982a..04422c4b6 100644 --- a/src/vak/prep/audio_dataset.py +++ b/src/vak/prep/audio_dataset.py @@ -5,15 +5,14 @@ import crowsetta import dask.bag as db -from dask.diagnostics import ProgressBar import numpy as np import pandas as pd +from dask.diagnostics import ProgressBar from ..common import annotation, constants from ..common.converters import expanded_user_path, labelset_to_set from .spectrogram_dataset.audio_helper import files_from_dir - logger = logging.getLogger(__name__) @@ -89,7 +88,10 @@ def prep_audio_dataset( annot_dir=data_dir, annot_format=annot_format ) scribe = crowsetta.Transcriber(format=annot_format) - annot_list = [scribe.from_file(annot_file).to_annot() for annot_file in annot_files] + annot_list = [ + scribe.from_file(annot_file).to_annot() + for annot_file in annot_files + ] else: scribe = 
crowsetta.Transcriber(format=annot_format) annot_list = scribe.from_file(annot_file).to_annot() @@ -101,10 +103,14 @@ def prep_audio_dataset( annot_list = None if annot_list: - audio_annot_map = annotation.map_annotated_to_annot(audio_files, annot_list, annot_format) + audio_annot_map = annotation.map_annotated_to_annot( + audio_files, annot_list, annot_format + ) else: # no annotation, so map spectrogram files to None - audio_annot_map = dict((audio_path, None) for audio_path in audio_files) + audio_annot_map = dict( + (audio_path, None) for audio_path in audio_files + ) # use mapping (if generated/supplied) with labelset, if supplied, to filter if labelset: # then remove annotations with labels not in labelset @@ -130,8 +136,10 @@ def _to_record(audio_annot_tuple): Accepts a two-element tuple containing (1) a dictionary that represents a spectrogram and (2) annotation for that file""" audio_path, annot = audio_annot_tuple - dat, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](audio_path) - sample_dur = 1. / samplerate + dat, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format]( + audio_path + ) + sample_dur = 1.0 / samplerate audio_dur = dat.shape[-1] * sample_dur if annot is not None: @@ -149,7 +157,9 @@ def abspath(a_path): [ abspath(audio_path), abspath(annot_path), - annot_format if annot_format else constants.NO_ANNOTATION_FORMAT, + annot_format + if annot_format + else constants.NO_ANNOTATION_FORMAT, samplerate, sample_dur, audio_dur, diff --git a/src/vak/prep/constants.py b/src/vak/prep/constants.py index b828ec815..68399dd4c 100644 --- a/src/vak/prep/constants.py +++ b/src/vak/prep/constants.py @@ -2,11 +2,7 @@ Defined in a separate module to minimize circular imports. """ -from . import ( - parametric_umap, - frame_classification -) - +from . 
import frame_classification, parametric_umap VALID_PURPOSES = frozenset( [ @@ -17,11 +13,11 @@ ] ) -INPUT_TYPES = {'audio', 'spect'} +INPUT_TYPES = {"audio", "spect"} DATASET_TYPE_FUNCTION_MAP = { - 'frame classification': frame_classification.prep_frame_classification_dataset, - 'parametric umap': parametric_umap.prep_parametric_umap_dataset, + "frame classification": frame_classification.prep_frame_classification_dataset, + "parametric umap": parametric_umap.prep_parametric_umap_dataset, } DATASET_TYPES = tuple(DATASET_TYPE_FUNCTION_MAP.keys()) diff --git a/src/vak/prep/dataset_df_helper.py b/src/vak/prep/dataset_df_helper.py index 2c20ffe07..81f73f9ba 100644 --- a/src/vak/prep/dataset_df_helper.py +++ b/src/vak/prep/dataset_df_helper.py @@ -31,7 +31,9 @@ def get_dataset_csv_filename(data_dir_name: str, timenow: str) -> str: return f"{data_dir_name}_prep_{timenow}.csv" -def get_dataset_csv_path(dataset_path: pathlib.Path, data_dir_name: str, timenow: str) -> pathlib.Path: +def get_dataset_csv_path( + dataset_path: pathlib.Path, data_dir_name: str, timenow: str +) -> pathlib.Path: """Returns the path that should be used to save a pandas DataFrame representing a dataset to a csv file. diff --git a/src/vak/prep/frame_classification/__init__.py b/src/vak/prep/frame_classification/__init__.py index 57a457f19..f9ee48004 100644 --- a/src/vak/prep/frame_classification/__init__.py +++ b/src/vak/prep/frame_classification/__init__.py @@ -1,16 +1,10 @@ +from . import dataset_arrays, frame_classification, learncurve, validators from .frame_classification import prep_frame_classification_dataset -from . 
import ( - dataset_arrays, - frame_classification, - learncurve, - validators, -) - __all__ = [ - 'dataset_arrays', - 'frame_classification', - 'learncurve', - 'prep_frame_classification_dataset', - 'validators', + "dataset_arrays", + "frame_classification", + "learncurve", + "prep_frame_classification_dataset", + "validators", ] diff --git a/src/vak/prep/frame_classification/dataset_arrays.py b/src/vak/prep/frame_classification/dataset_arrays.py index 7a8a31447..4e6744288 100644 --- a/src/vak/prep/frame_classification/dataset_arrays.py +++ b/src/vak/prep/frame_classification/dataset_arrays.py @@ -9,55 +9,44 @@ import attrs import crowsetta import dask.bag as db -from dask.diagnostics import ProgressBar import numpy as np import pandas as pd +from dask.diagnostics import ProgressBar +from ... import common, datasets, transforms from .. import constants as prep_constants -from ... import ( - common, - datasets, - transforms -) - logger = logging.getLogger(__name__) -def argsort_by_label_freq( - annots: list[crowsetta.Annotation] -) -> list[int]: +def argsort_by_label_freq(annots: list[crowsetta.Annotation]) -> list[int]: """Returns indices to sort a list of annotations - in order of more frequently appearing labels, - i.e., the first annotation will have the label - that appears least frequently and the last annotation - will have the label that appears most frequently. - - Used to sort a dataframe representing a dataset of annotated audio - or spectrograms before cropping that dataset to a specified duration, - so that it's less likely that cropping will remove all occurrences - of any label class from the total dataset. - - Parameters - ---------- - annots: list - List of :class:`crowsetta.Annotation` instances. - - Returns - ------- - sort_inds: list - Integer values to sort ``annots``. 
+ in order of more frequently appearing labels, + i.e., the first annotation will have the label + that appears least frequently and the last annotation + will have the label that appears most frequently. + + Used to sort a dataframe representing a dataset of annotated audio + or spectrograms before cropping that dataset to a specified duration, + so that it's less likely that cropping will remove all occurrences + of any label class from the total dataset. + + Parameters + ---------- + annots: list + List of :class:`crowsetta.Annotation` instances. + + Returns + ------- + sort_inds: list + Integer values to sort ``annots``. """ - all_labels = [ - lbl for annot in annots for lbl in annot.seq.labels - ] + all_labels = [lbl for annot in annots for lbl in annot.seq.labels] label_counts = collections.Counter(all_labels) sort_inds = [] # make indices ahead of time so they stay constant as we remove things from the list - ind_annot_tuples = list( - enumerate(copy.deepcopy(annots)) - ) + ind_annot_tuples = list(enumerate(copy.deepcopy(annots))) for label, _ in reversed(label_counts.most_common()): # next line, [:] to make a temporary copy to avoid remove bug for ind_annot_tuple in ind_annot_tuples[:]: @@ -79,9 +68,7 @@ def argsort_by_label_freq( f"Left over (with indices from list): {ind_annot_tuples}" ) - if not ( - sorted(sort_inds) == list(range(len(annots))) - ): + if not (sorted(sort_inds) == list(range(len(annots)))): raise ValueError( "sorted(sort_inds) does not equal range(len(annots)):" f"sort_inds: {sort_inds}\nrange(len(annots)): {list(range(len(annots)))}" @@ -92,13 +79,14 @@ def argsort_by_label_freq( @attrs.define(frozen=True) class Sample: - """Dataclass representing one sample + """Dataclass representing one sample in a frame classification dataset. 
- - Used to add paths for arrays from the sample - to a ``dataset_df``, and to build - the ``sample_ids`` vector and ``inds_in_sample`` vector + + Used to add paths for arrays from the sample + to a ``dataset_df``, and to build + the ``sample_ids`` vector and ``inds_in_sample`` vector for the entire dataset.""" + source_id: int = attrs.field() frame_npy_path: str frame_labels_npy_path: str @@ -107,14 +95,14 @@ class Sample: def make_npy_files_for_each_split( - dataset_df: pd.DataFrame, - dataset_path: str | pathlib.Path, - input_type: str, - purpose: str, - labelmap: dict, - audio_format: str, - spect_key: str = 's', - timebins_key: str = 't', + dataset_df: pd.DataFrame, + dataset_path: str | pathlib.Path, + input_type: str, + purpose: str, + labelmap: dict, + audio_format: str, + spect_key: str = "s", + timebins_key: str = "t", ): r"""Make npy files containing arrays for each split of a frame classification dataset. @@ -204,32 +192,34 @@ def make_npy_files_for_each_split( split_df = dataset_df[dataset_df.split == split].copy() - if purpose != 'predict': + if purpose != "predict": annots = common.annotation.from_df(split_df) else: annots = None if annots: sort_inds = argsort_by_label_freq(annots) - split_df['sort_inds'] = sort_inds - split_df = split_df.sort_values(by='sort_inds').drop(columns='sort_inds').reset_index() + split_df["sort_inds"] = sort_inds + split_df = ( + split_df.sort_values(by="sort_inds") + .drop(columns="sort_inds") + .reset_index() + ) - if input_type == 'audio': - source_paths = split_df['audio_path'].values - elif input_type == 'spect': - source_paths = split_df['spect_path'].values + if input_type == "audio": + source_paths = split_df["audio_path"].values + elif input_type == "spect": + source_paths = split_df["spect_path"].values else: - raise ValueError( - f"Invalid ``input_type``: {input_type}" - ) + raise ValueError(f"Invalid ``input_type``: {input_type}") # do this *again* after sorting the dataframe - if purpose != 'predict': + if 
purpose != "predict": annots = common.annotation.from_df(split_df) else: annots = None def _save_dataset_arrays_and_return_index_arrays( - source_id_path_annot_tup + source_id_path_annot_tup, ): """Function we use with dask to parallelize @@ -238,19 +228,24 @@ def _save_dataset_arrays_and_return_index_arrays( source_id, source_path, annot = source_id_path_annot_tup source_path = pathlib.Path(source_path) - if input_type == 'audio': - frames, samplefreq = common.constants.AUDIO_FORMAT_FUNC_MAP[audio_format](source_path) - if audio_format == 'cbin': # convert to ~wav, from int16 to float64 + if input_type == "audio": + frames, samplefreq = common.constants.AUDIO_FORMAT_FUNC_MAP[ + audio_format + ](source_path) + if ( + audio_format == "cbin" + ): # convert to ~wav, from int16 to float64 frames = frames.astype(np.float64) / 32768.0 if annot: frame_times = np.arange(frames.shape[-1]) / samplefreq - elif input_type == 'spect': + elif input_type == "spect": spect_dict = np.load(source_path) frames = spect_dict[spect_key] if annot: frame_times = spect_dict[timebins_key] frames_npy_path = split_subdir / ( - source_path.stem + datasets.frame_classification.constants.FRAMES_ARRAY_EXT + source_path.stem + + datasets.frame_classification.constants.FRAMES_ARRAY_EXT ) np.save(frames_npy_path, frames) frames_npy_path = str( @@ -273,7 +268,8 @@ def _save_dataset_arrays_and_return_index_arrays( unlabeled_label=labelmap["unlabeled"], ) frame_labels_npy_path = split_subdir / ( - source_path.stem + datasets.frame_classification.constants.FRAME_LABELS_EXT + source_path.stem + + datasets.frame_classification.constants.FRAME_LABELS_EXT ) np.save(frame_labels_npy_path, frame_labels) frame_labels_npy_path = str( @@ -288,7 +284,7 @@ def _save_dataset_arrays_and_return_index_arrays( frames_npy_path, frame_labels_npy_path, sample_id_vec, - inds_in_sample_vec + inds_in_sample_vec, ) # ---- make npy files for this split, parallelized with dask @@ -296,7 +292,9 @@ def 
_save_dataset_arrays_and_return_index_arrays( if annots: source_path_annot_tups = [ (source_id, source_path, annot) - for source_id, (source_path, annot) in enumerate(zip(source_paths, annots)) + for source_id, (source_path, annot) in enumerate( + zip(source_paths, annots) + ) ] else: source_path_annot_tups = [ @@ -306,9 +304,11 @@ def _save_dataset_arrays_and_return_index_arrays( source_path_annot_bag = db.from_sequence(source_path_annot_tups) with ProgressBar(): - samples = list(source_path_annot_bag.map( - _save_dataset_arrays_and_return_index_arrays - )) + samples = list( + source_path_annot_bag.map( + _save_dataset_arrays_and_return_index_arrays + ) + ) samples = sorted(samples, key=lambda sample: sample.source_id) # ---- save indexing vectors in split directory @@ -316,24 +316,30 @@ def _save_dataset_arrays_and_return_index_arrays( list(sample.sample_id_vec for sample in samples) ) np.save( - split_subdir / datasets.frame_classification.constants.SAMPLE_IDS_ARRAY_FILENAME, sample_id_vec + split_subdir + / datasets.frame_classification.constants.SAMPLE_IDS_ARRAY_FILENAME, + sample_id_vec, ) inds_in_sample_vec = np.concatenate( list(sample.inds_in_sample_vec for sample in samples) ) np.save( - split_subdir / datasets.frame_classification.constants.INDS_IN_SAMPLE_ARRAY_FILENAME, inds_in_sample_vec + split_subdir + / datasets.frame_classification.constants.INDS_IN_SAMPLE_ARRAY_FILENAME, + inds_in_sample_vec, ) - frame_npy_paths = [ - str(sample.frame_npy_path) for sample in samples - ] - split_df[datasets.frame_classification.constants.FRAMES_NPY_PATH_COL_NAME] = frame_npy_paths + frame_npy_paths = [str(sample.frame_npy_path) for sample in samples] + split_df[ + datasets.frame_classification.constants.FRAMES_NPY_PATH_COL_NAME + ] = frame_npy_paths frame_labels_npy_paths = [ str(sample.frame_labels_npy_path) for sample in samples ] - split_df[datasets.frame_classification.constants.FRAME_LABELS_NPY_PATH_COL_NAME] = frame_labels_npy_paths + split_df[ + 
datasets.frame_classification.constants.FRAME_LABELS_NPY_PATH_COL_NAME + ] = frame_labels_npy_paths dataset_df_out.append(split_df) dataset_df_out = pd.concat(dataset_df_out) diff --git a/src/vak/prep/frame_classification/frame_classification.py b/src/vak/prep/frame_classification/frame_classification.py index a58c83d4f..620000297 100644 --- a/src/vak/prep/frame_classification/frame_classification.py +++ b/src/vak/prep/frame_classification/frame_classification.py @@ -5,18 +5,16 @@ import crowsetta.formats.seq -from . import dataset_arrays, validators -from .learncurve import make_learncurve_splits_from_dataset_df -from .. import dataset_df_helper, sequence_dataset, split -from ..audio_dataset import prep_audio_dataset -from ..spectrogram_dataset.prep import prep_spectrogram_dataset - from ... import datasets from ...common import labels from ...common.converters import expanded_user_path, labelset_to_set from ...common.logging import config_logging_for_cli, log_version from ...common.timenow import get_timenow_as_str - +from .. import dataset_df_helper, sequence_dataset, split +from ..audio_dataset import prep_audio_dataset +from ..spectrogram_dataset.prep import prep_spectrogram_dataset +from . import dataset_arrays, validators +from .learncurve import make_learncurve_splits_from_dataset_df logger = logging.getLogger(__name__) @@ -34,7 +32,7 @@ def prep_frame_classification_dataset( labelset: set | None = None, audio_dask_bag_kwargs: dict | None = None, train_dur: int | None = None, - val_dur: int | None =None, + val_dur: int | None = None, test_dur: int | None = None, train_set_durs: list[float] | None = None, num_replicates: int | None = None, @@ -142,19 +140,21 @@ def prep_frame_classification_dataset( f"Value for ``input_type`` was: {input_type}" ) - if input_type == 'audio' and spect_format is not None: + if input_type == "audio" and spect_format is not None: raise ValueError( f"Input type was 'audio' but a ``spect_format`` was specified: '{spect_format}'. 
" f"Please specify ``audio_format`` only." ) - if input_type == 'audio' and audio_format is None: + if input_type == "audio" and audio_format is None: raise ValueError( - f"Input type was 'audio' but no ``audio_format`` was specified. " + "Input type was 'audio' but no ``audio_format`` was specified. " ) if audio_format is None and spect_format is None: - raise ValueError("Must specify either ``audio_format`` or ``spect_format``") + raise ValueError( + "Must specify either ``audio_format`` or ``spect_format``" + ) if audio_format and spect_format: raise ValueError( @@ -168,7 +168,9 @@ def prep_frame_classification_dataset( data_dir = expanded_user_path(data_dir) if not data_dir.is_dir(): - raise NotADirectoryError(f"Path specified for ``data_dir`` not found: {data_dir}") + raise NotADirectoryError( + f"Path specified for ``data_dir`` not found: {data_dir}" + ) if output_dir: output_dir = expanded_user_path(output_dir) @@ -176,13 +178,15 @@ def prep_frame_classification_dataset( output_dir = data_dir if not output_dir.is_dir(): - raise NotADirectoryError(f"Path specified for ``output_dir`` not found: {output_dir}") + raise NotADirectoryError( + f"Path specified for ``output_dir`` not found: {output_dir}" + ) if annot_file is not None: annot_file = expanded_user_path(annot_file) if not annot_file.exists(): raise FileNotFoundError( - f'Path specified for ``annot_file`` not found: {annot_file}' + f"Path specified for ``annot_file`` not found: {annot_file}" ) if purpose == "predict": @@ -190,7 +194,7 @@ def prep_frame_classification_dataset( warnings.warn( "The ``purpose`` argument was set to 'predict`, but a ``labelset`` was provided." "This would cause an error because the ``prep_spectrogram_dataset`` section will attempt to " - f"check whether the files in the ``data_dir`` have labels in " + "check whether the files in the ``data_dir`` have labels in " "``labelset``, even though those files don't have annotation.\n" "Setting ``labelset`` to None." 
) @@ -210,48 +214,56 @@ def prep_frame_classification_dataset( # ---- set up directory that will contain dataset, and csv file name ----------------------------------------------- data_dir_name = data_dir.name timenow = get_timenow_as_str() - dataset_path = output_dir / f'{data_dir_name}-vak-frame-classification-dataset-generated-{timenow}' + dataset_path = ( + output_dir + / f"{data_dir_name}-vak-frame-classification-dataset-generated-{timenow}" + ) dataset_path.mkdir() - if annot_file and annot_format == 'birdsong-recognition-dataset': + if annot_file and annot_format == "birdsong-recognition-dataset": # we do this normalization / canonicalization after we make dataset_path # so that we can put the new annot_file inside of dataset_path, instead of # making new files elsewhere on a user's system - logger.info("The ``annot_format`` argument was set to 'birdsong-recognition-format'; " - "this format requires the audio files for their sampling rate " - "to convert onset and offset times of birdsong syllables to seconds." - "Converting this format to 'generic-seq' now with the times in seconds, " - "so that the dataset prepared by vak will not require the audio files.") + logger.info( + "The ``annot_format`` argument was set to 'birdsong-recognition-format'; " + "this format requires the audio files for their sampling rate " + "to convert onset and offset times of birdsong syllables to seconds." + "Converting this format to 'generic-seq' now with the times in seconds, " + "so that the dataset prepared by vak will not require the audio files." 
+ ) birdsongrec = crowsetta.formats.seq.BirdsongRec.from_file(annot_file) annots = birdsongrec.to_annot() # note we point `annot_file` at a new file we're about to make - annot_file = dataset_path / f'{annot_file.stem}.converted-to-generic-seq.csv' + annot_file = ( + dataset_path / f"{annot_file.stem}.converted-to-generic-seq.csv" + ) # and we remake Annotations here so that annot_path points to this new file, not the birdsong-rec Annotation.xml annots = [ - crowsetta.Annotation(seq=annot.seq, annot_path=annot_file, notated_path=annot.notated_path) + crowsetta.Annotation( + seq=annot.seq, + annot_path=annot_file, + notated_path=annot.notated_path, + ) for annot in annots ] generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots) generic_seq.to_file(annot_file) # and we now change `annot_format` as well. Both these will get passed to io.prep_spectrogram_dataset - annot_format = 'generic-seq' + annot_format = "generic-seq" # NOTE we set up logging here (instead of cli) so the prep log is included in the dataset config_logging_for_cli( - log_dst=dataset_path, - log_stem="prep", - level="INFO", - force=True + log_dst=dataset_path, log_stem="prep", level="INFO", force=True ) log_version(logger) - dataset_csv_path = dataset_df_helper.get_dataset_csv_path(dataset_path, data_dir_name, timenow) - logger.info( - f"Will prepare dataset as directory: {dataset_path}" + dataset_csv_path = dataset_df_helper.get_dataset_csv_path( + dataset_path, data_dir_name, timenow ) + logger.info(f"Will prepare dataset as directory: {dataset_path}") # ---- actually make the dataset ----------------------------------------------------------------------------------- - if input_type == 'spect': + if input_type == "spect": dataset_df = prep_spectrogram_dataset( labelset=labelset, data_dir=data_dir, @@ -263,7 +275,7 @@ def prep_frame_classification_dataset( spect_output_dir=dataset_path, audio_dask_bag_kwargs=audio_dask_bag_kwargs, ) - elif input_type == 'audio': + elif input_type == "audio": 
dataset_df = prep_audio_dataset( audio_format=audio_format, data_dir=data_dir, @@ -297,7 +309,9 @@ def prep_frame_classification_dataset( "zero for test_dur (and val_dur, if a validation set will be used)" ) - if all([dur is None for dur in (train_dur, val_dur, test_dur)]) or purpose in ( + if all( + [dur is None for dur in (train_dur, val_dur, test_dur)] + ) or purpose in ( "eval", "predict", ): @@ -329,18 +343,26 @@ def prep_frame_classification_dataset( # ideally we would just say split=purpose in call to add_split_col, but # we have to special case, because "eval" looks for a 'test' split (not an "eval" split) if purpose == "eval": - split_name = "test" # 'split_name' to avoid name clash with split package + split_name = ( + "test" # 'split_name' to avoid name clash with split package + ) elif purpose == "predict": split_name = "predict" - dataset_df = dataset_df_helper.add_split_col(dataset_df, split=split_name) + dataset_df = dataset_df_helper.add_split_col( + dataset_df, split=split_name + ) # ---- create and save labelmap ------------------------------------------------------------------------------------ # we do this before creating array files since we need to load the labelmap to make frame label vectors - if purpose != 'predict': + if purpose != "predict": # TODO: add option to generate predict using existing dataset, so we can get labelmap from it - map_unlabeled_segments = sequence_dataset.has_unlabeled_segments(dataset_df) - labelmap = labels.to_map(labelset, map_unlabeled=map_unlabeled_segments) + map_unlabeled_segments = sequence_dataset.has_unlabeled_segments( + dataset_df + ) + labelmap = labels.to_map( + labelset, map_unlabeled=map_unlabeled_segments + ) logger.info( f"Number of classes in labelmap: {len(labelmap)}", ) @@ -363,7 +385,7 @@ def prep_frame_classification_dataset( ) # ---- if purpose is learncurve, additionally prep splits for that ------------------------------------------------- - if purpose == 'learncurve': + if purpose == 
"learncurve": dataset_df = make_learncurve_splits_from_dataset_df( dataset_df, input_type, @@ -377,9 +399,7 @@ def prep_frame_classification_dataset( ) # ---- save csv file that captures provenance of source data ------------------------------------------------------- - logger.info( - f"Saving dataset csv file: {dataset_csv_path}" - ) + logger.info(f"Saving dataset csv file: {dataset_csv_path}") dataset_df.to_csv( dataset_csv_path, index=False ) # index is False to avoid having "Unnamed: 0" column when loading @@ -387,10 +407,10 @@ def prep_frame_classification_dataset( # ---- save metadata ----------------------------------------------------------------------------------------------- frame_dur = validators.validate_and_get_frame_dur(dataset_df, input_type) - if input_type == 'spect' and spect_format != 'npz': + if input_type == "spect" and spect_format != "npz": # then change to npz since we canonicalize data so it's always npz arrays # We need this to be correct for other functions, e.g. predict when it loads spectrogram files - spect_format = 'npz' + spect_format = "npz" metadata = datasets.frame_classification.Metadata( dataset_csv_filename=str(dataset_csv_path.name), diff --git a/src/vak/prep/frame_classification/learncurve.py b/src/vak/prep/frame_classification/learncurve.py index 1424642cb..96392967f 100644 --- a/src/vak/prep/frame_classification/learncurve.py +++ b/src/vak/prep/frame_classification/learncurve.py @@ -8,12 +8,9 @@ import pandas as pd -from .dataset_arrays import ( - make_npy_files_for_each_split, -) -from .. import split from ... import common - +from .. 
import split +from .dataset_arrays import make_npy_files_for_each_split logger = logging.getLogger(__name__) @@ -97,30 +94,37 @@ def make_learncurve_splits_from_dataset_df( f"Subsetting training set for training set of duration: {train_dur}", ) for replicate_num in range(1, num_replicates + 1): - train_dur_replicate_split_name = common.learncurve.get_train_dur_replicate_split_name( - train_dur, replicate_num + train_dur_replicate_split_name = ( + common.learncurve.get_train_dur_replicate_split_name( + train_dur, replicate_num + ) ) train_dur_replicate_df = split.frame_classification_dataframe( # copy to avoid mutating original train_split_df - train_split_df.copy(), dataset_path, train_dur=train_dur, labelset=labelset + train_split_df.copy(), + dataset_path, + train_dur=train_dur, + labelset=labelset, ) # remove rows where split set to 'None' - train_dur_replicate_df = train_dur_replicate_df[train_dur_replicate_df.split == "train"] + train_dur_replicate_df = train_dur_replicate_df[ + train_dur_replicate_df.split == "train" + ] # next line, make split name in csv match the split name used for directory in dataset dir - train_dur_replicate_df['split'] = train_dur_replicate_split_name - train_dur_replicate_df['train_dur'] = train_dur - train_dur_replicate_df['replicate_num'] = replicate_num - all_train_durs_and_replicates_df.append( - train_dur_replicate_df - ) + train_dur_replicate_df["split"] = train_dur_replicate_split_name + train_dur_replicate_df["train_dur"] = train_dur + train_dur_replicate_df["replicate_num"] = replicate_num + all_train_durs_and_replicates_df.append(train_dur_replicate_df) - all_train_durs_and_replicates_df = pd.concat(all_train_durs_and_replicates_df) + all_train_durs_and_replicates_df = pd.concat( + all_train_durs_and_replicates_df + ) all_train_durs_and_replicates_df = make_npy_files_for_each_split( all_train_durs_and_replicates_df, dataset_path, input_type, - 'learncurve', # purpose + "learncurve", # purpose labelmap, audio_format, 
spect_key, diff --git a/src/vak/prep/frame_classification/validators.py b/src/vak/prep/frame_classification/validators.py index 60b02ca1a..91d56be7b 100644 --- a/src/vak/prep/frame_classification/validators.py +++ b/src/vak/prep/frame_classification/validators.py @@ -4,7 +4,9 @@ import pandas as pd -def validate_and_get_frame_dur(dataset_df: pd.DataFrame, input_type: str) -> float: +def validate_and_get_frame_dur( + dataset_df: pd.DataFrame, input_type: str +) -> float: """Validate that there is a single, unique value for the frame duration for all samples (audio signals / spectrograms) in a dataset. If so, return that value. @@ -36,9 +38,9 @@ def validate_and_get_frame_dur(dataset_df: pd.DataFrame, input_type: str) -> flo ) # TODO: handle possible KeyError here? - if input_type == 'audio': + if input_type == "audio": frame_dur = dataset_df["sample_dur"].unique() - elif input_type == 'spect': + elif input_type == "spect": frame_dur = dataset_df["timebin_dur"].unique() if len(frame_dur) > 1: diff --git a/src/vak/prep/parametric_umap/__init__.py b/src/vak/prep/parametric_umap/__init__.py index af55977fe..893b8659c 100644 --- a/src/vak/prep/parametric_umap/__init__.py +++ b/src/vak/prep/parametric_umap/__init__.py @@ -1,2 +1,8 @@ from . 
import dataset_arrays from .parametric_umap import prep_parametric_umap_dataset + + +__all__ = [ + "dataset_arrays", + "prep_parametric_umap_dataset", +] diff --git a/src/vak/prep/parametric_umap/dataset_arrays.py b/src/vak/prep/parametric_umap/dataset_arrays.py index 84ce2b108..67e224ae7 100644 --- a/src/vak/prep/parametric_umap/dataset_arrays.py +++ b/src/vak/prep/parametric_umap/dataset_arrays.py @@ -9,14 +9,12 @@ import pandas as pd - logger = logging.getLogger(__name__) def move_files_into_split_subdirs( - dataset_df: pd.DataFrame, - dataset_path: pathlib.Path, - purpose: str) -> None: + dataset_df: pd.DataFrame, dataset_path: pathlib.Path, purpose: str +) -> None: """Move npy files in dataset into sub-directories, one for each split in the dataset. This is run *after* calling :func:`vak.prep.unit_dataset.prep_unit_dataset` @@ -47,13 +45,15 @@ def move_files_into_split_subdirs( The ``DataFrame`` is modified in place as the files are moved, so nothing is returned. """ - moved_spect_paths = [] # to clean up after moving -- may be empty if we copy all spects (e.g., user generated) + moved_spect_paths = ( + [] + ) # to clean up after moving -- may be empty if we copy all spects (e.g., user generated) # ---- copy/move files into split sub-directories inside dataset directory # Next line, note we drop any na rows in the split column, since they don't belong to a split anyway split_names = sorted(dataset_df.split.dropna().unique()) for split_name in split_names: - if split_name == 'None': + if split_name == "None": # these are files that didn't get assigned to a split continue split_subdir = dataset_path / split_name @@ -63,7 +63,7 @@ def move_files_into_split_subdirs( split_spect_paths = [ # this just converts from string to pathlib.Path pathlib.Path(spect_path) - for spect_path in split_df['spect_path'].values + for spect_path in split_df["spect_path"].values ] is_in_dataset_dir = [ # if dataset_path is one of the parents of spect_path, we can move; otherwise, we 
copy @@ -88,13 +88,9 @@ def move_files_into_split_subdirs( new_spect_path = spect_path.rename( split_subdir / spect_path.name ) - moved_spect_paths.append( - spect_path - ) + moved_spect_paths.append(spect_path) else: # copy instead of moving - new_spect_path = shutil.copy( - src=spect_path, dst=split_subdir - ) + new_spect_path = shutil.copy(src=spect_path, dst=split_subdir) new_spect_paths.append( # rewrite paths relative to dataset directory's root, so dataset is portable @@ -102,15 +98,17 @@ def move_files_into_split_subdirs( ) # cast to str before rewrite so that type doesn't silently change for some rows - new_spect_paths = [str(new_spect_path) for new_spect_path in new_spect_paths] - dataset_df.loc[split_df.index, 'spect_path'] = new_spect_paths + new_spect_paths = [ + str(new_spect_path) for new_spect_path in new_spect_paths + ] + dataset_df.loc[split_df.index, "spect_path"] = new_spect_paths # ---- clean up after moving/copying ------------------------------------------------------------------------------- # remove any directories that we just emptied if moved_spect_paths: - unique_parents = set([ - moved_spect.parent for moved_spect in moved_spect_paths - ]) + unique_parents = set( + [moved_spect.parent for moved_spect in moved_spect_paths] + ) for parent in unique_parents: if len(list(parent.iterdir())) < 1: shutil.rmtree(parent) diff --git a/src/vak/prep/parametric_umap/parametric_umap.py b/src/vak/prep/parametric_umap/parametric_umap.py index 39350669a..7b5be8fa8 100644 --- a/src/vak/prep/parametric_umap/parametric_umap.py +++ b/src/vak/prep/parametric_umap/parametric_umap.py @@ -5,16 +5,14 @@ import crowsetta -from . import dataset_arrays -from .. import dataset_df_helper, split -from ..unit_dataset import prep_unit_dataset - from ... 
import datasets from ...common import labels from ...common.converters import expanded_user_path, labelset_to_set from ...common.logging import config_logging_for_cli, log_version from ...common.timenow import get_timenow_as_str - +from .. import dataset_df_helper, split +from ..unit_dataset import prep_unit_dataset +from . import dataset_arrays logger = logging.getLogger(__name__) @@ -116,7 +114,9 @@ def prep_parametric_umap_dataset( data_dir = expanded_user_path(data_dir) if not data_dir.is_dir(): - raise NotADirectoryError(f"Path specified for ``data_dir`` not found: {data_dir}") + raise NotADirectoryError( + f"Path specified for ``data_dir`` not found: {data_dir}" + ) if output_dir: output_dir = expanded_user_path(output_dir) @@ -124,13 +124,15 @@ def prep_parametric_umap_dataset( output_dir = data_dir if not output_dir.is_dir(): - raise NotADirectoryError(f"Path specified for ``output_dir`` not found: {output_dir}") + raise NotADirectoryError( + f"Path specified for ``output_dir`` not found: {output_dir}" + ) if annot_file is not None: annot_file = expanded_user_path(annot_file) if not annot_file.exists(): raise FileNotFoundError( - f'Path specified for ``annot_file`` not found: {annot_file}' + f"Path specified for ``annot_file`` not found: {annot_file}" ) if purpose == "predict": @@ -138,7 +140,7 @@ def prep_parametric_umap_dataset( warnings.warn( "The ``purpose`` argument was set to 'predict`, but a ``labelset`` was provided." "This would cause an error because the ``prep_spectrogram_dataset`` section will attempt to " - f"check whether the files in the ``data_dir`` have labels in " + "check whether the files in the ``data_dir`` have labels in " "``labelset``, even though those files don't have annotation.\n" "Setting ``labelset`` to None." 
) @@ -158,45 +160,53 @@ def prep_parametric_umap_dataset( # ---- set up directory that will contain dataset, and csv file name ----------------------------------------------- data_dir_name = data_dir.name timenow = get_timenow_as_str() - dataset_path = output_dir / f'{data_dir_name}-vak-dimensionality-reduction-dataset-generated-{timenow}' + dataset_path = ( + output_dir + / f"{data_dir_name}-vak-dimensionality-reduction-dataset-generated-{timenow}" + ) dataset_path.mkdir() - if annot_file and annot_format == 'birdsong-recognition-dataset': + if annot_file and annot_format == "birdsong-recognition-dataset": # we do this normalization / canonicalization after we make dataset_path # so that we can put the new annot_file inside of dataset_path, instead of # making new files elsewhere on a user's system - logger.info("The ``annot_format`` argument was set to 'birdsong-recognition-format'; " - "this format requires the audio files for their sampling rate " - "to convert onset and offset times of birdsong syllables to seconds." - "Converting this format to 'generic-seq' now with the times in seconds, " - "so that the dataset prepared by vak will not require the audio files.") + logger.info( + "The ``annot_format`` argument was set to 'birdsong-recognition-format'; " + "this format requires the audio files for their sampling rate " + "to convert onset and offset times of birdsong syllables to seconds." + "Converting this format to 'generic-seq' now with the times in seconds, " + "so that the dataset prepared by vak will not require the audio files." 
+ ) birdsongrec = crowsetta.formats.seq.BirdsongRec.from_file(annot_file) annots = birdsongrec.to_annot() # note we point `annot_file` at a new file we're about to make - annot_file = dataset_path / f'{annot_file.stem}.converted-to-generic-seq.csv' + annot_file = ( + dataset_path / f"{annot_file.stem}.converted-to-generic-seq.csv" + ) # and we remake Annotations here so that annot_path points to this new file, not the birdsong-rec Annotation.xml annots = [ - crowsetta.Annotation(seq=annot.seq, annot_path=annot_file, notated_path=annot.notated_path) + crowsetta.Annotation( + seq=annot.seq, + annot_path=annot_file, + notated_path=annot.notated_path, + ) for annot in annots ] generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots) generic_seq.to_file(annot_file) # and we now change `annot_format` as well. Both these will get passed to io.prep_spectrogram_dataset - annot_format = 'generic-seq' + annot_format = "generic-seq" # NOTE we set up logging here (instead of cli) so the prep log is included in the dataset config_logging_for_cli( - log_dst=dataset_path, - log_stem="prep", - level="INFO", - force=True + log_dst=dataset_path, log_stem="prep", level="INFO", force=True ) log_version(logger) - dataset_csv_path = dataset_df_helper.get_dataset_csv_path(dataset_path, data_dir_name, timenow) - logger.info( - f"Will prepare dataset as directory: {dataset_path}" + dataset_csv_path = dataset_df_helper.get_dataset_csv_path( + dataset_path, data_dir_name, timenow ) + logger.info(f"Will prepare dataset as directory: {dataset_path}") # ---- actually make the dataset ----------------------------------------------------------------------------------- dataset_df, shape = prep_unit_dataset( @@ -236,7 +246,9 @@ def prep_parametric_umap_dataset( "zero for test_dur (and val_dur, if a validation set will be used)" ) - if all([dur is None for dur in (train_dur, val_dur, test_dur)]) or purpose in ( + if all( + [dur is None for dur in (train_dur, val_dur, test_dur)] + ) or purpose in 
( "eval", "predict", ): @@ -268,15 +280,19 @@ def prep_parametric_umap_dataset( # ideally we would just say split=purpose in call to add_split_col, but # we have to special case, because "eval" looks for a 'test' split (not an "eval" split) if purpose == "eval": - split_name = "test" # 'split_name' to avoid name clash with split package + split_name = ( + "test" # 'split_name' to avoid name clash with split package + ) elif purpose == "predict": split_name = "predict" - dataset_df = dataset_df_helper.add_split_col(dataset_df, split=split_name) + dataset_df = dataset_df_helper.add_split_col( + dataset_df, split=split_name + ) # ---- create and save labelmap ------------------------------------------------------------------------------------ # we do this before creating array files since we need to load the labelmap to make frame label vectors - if purpose != 'predict': + if purpose != "predict": # TODO: add option to generate predict using existing dataset, so we can get labelmap from it labelmap = labels.to_map(labelset, map_unlabeled=False) logger.info( @@ -295,7 +311,7 @@ def prep_parametric_umap_dataset( purpose, ) # - # # ---- if purpose is learncurve, additionally prep splits for that ------------------------------------------------- + # ---- if purpose is learncurve, additionally prep splits for that ----------------------------------------------- # if purpose == 'learncurve': # dataset_df = make_learncurve_splits_from_dataset_df( # dataset_df, @@ -309,9 +325,7 @@ def prep_parametric_umap_dataset( # ) # ---- save csv file that captures provenance of source data ------------------------------------------------------- - logger.info( - f"Saving dataset csv file: {dataset_csv_path}" - ) + logger.info(f"Saving dataset csv file: {dataset_csv_path}") dataset_df.to_csv( dataset_csv_path, index=False ) # index is False to avoid having "Unnamed: 0" column when loading diff --git a/src/vak/prep/prep.py b/src/vak/prep/prep.py index 277855a34..b9aebbdbc 100644 --- 
a/src/vak/prep/prep.py +++ b/src/vak/prep/prep.py @@ -1,12 +1,9 @@ import logging import pathlib -from . import ( - constants, -) -from .parametric_umap import prep_parametric_umap_dataset +from . import constants from .frame_classification import prep_frame_classification_dataset - +from .parametric_umap import prep_parametric_umap_dataset logger = logging.getLogger(__name__) @@ -25,7 +22,7 @@ def prep( labelset: set | None = None, audio_dask_bag_kwargs: dict | None = None, train_dur: int | None = None, - val_dur: int | None =None, + val_dur: int | None = None, test_dur: int | None = None, train_set_durs: list[float] | None = None, num_replicates: int | None = None, @@ -171,7 +168,7 @@ def prep( f"Value for ``input_type`` was: {input_type}" ) - if input_type == 'audio' and spect_format is not None: + if input_type == "audio" and spect_format is not None: raise ValueError( f"``input_type`` was set to 'audio' but a ``spect_format`` was specified: '{spect_format}'.\n" f"Please only provide a ``spect_format`` argument when the input type to the neural network " @@ -191,7 +188,7 @@ def prep( # we have to use an if-else here since args may vary across dataset prep functions # but we still define DATASET_TYPE_FUNC_MAP in vak.prep.constants # so that the mapping is made explicit in the code - if dataset_type == 'frame classification': + if dataset_type == "frame classification": dataset_df, dataset_path = prep_frame_classification_dataset( data_dir, input_type, @@ -236,6 +233,4 @@ def prep( else: # this is in case a dataset type is written wrong # in the if-else statements above, we want to error loudly - raise ValueError( - f"Unrecognized dataset type: {dataset_type}" - ) + raise ValueError(f"Unrecognized dataset type: {dataset_type}") diff --git a/src/vak/prep/sequence_dataset.py b/src/vak/prep/sequence_dataset.py index 1a98cf83c..11ec2df86 100644 --- a/src/vak/prep/sequence_dataset.py +++ b/src/vak/prep/sequence_dataset.py @@ -35,9 +35,7 @@ def 
where_unlabeled_segments(dataset_df: pd.DataFrame) -> npt.NDArray: has_unlabeled_list = [] for annot, duration in zip(annots, durations): - has_unlabeled_list.append( - annotation.has_unlabeled(annot, duration) - ) + has_unlabeled_list.append(annotation.has_unlabeled(annot, duration)) return np.array(has_unlabeled_list).astype(bool) @@ -69,6 +67,4 @@ def has_unlabeled_segments(dataset_df: pd.DataFrame) -> bool: # np.any returns an instance of # and ` is True == False`. # Not sure if this is numpy version specific - return bool( - np.any(where_unlabeled_segments(dataset_df) - )) + return bool(np.any(where_unlabeled_segments(dataset_df))) diff --git a/src/vak/prep/spectrogram_dataset/__init__.py b/src/vak/prep/spectrogram_dataset/__init__.py index 5b79863bc..7e1c56b1b 100644 --- a/src/vak/prep/spectrogram_dataset/__init__.py +++ b/src/vak/prep/spectrogram_dataset/__init__.py @@ -1,3 +1,8 @@ """Functions for preparing a dataset for neural network models from a dataset of spectrograms.""" from .prep import prep_spectrogram_dataset + + +__all__ = [ + "prep_spectrogram_dataset", +] diff --git a/src/vak/prep/spectrogram_dataset/audio_helper.py b/src/vak/prep/spectrogram_dataset/audio_helper.py index 3ef3342c2..2c84a7f2a 100644 --- a/src/vak/prep/spectrogram_dataset/audio_helper.py +++ b/src/vak/prep/spectrogram_dataset/audio_helper.py @@ -1,22 +1,20 @@ from __future__ import annotations + import logging import os from pathlib import Path -import numpy as np import dask.bag as db +import numpy as np from dask.diagnostics import ProgressBar -from ... import ( - config, -) +from ... 
import config from ...common import constants, files from ...common.annotation import map_annotated_to_annot from ...common.converters import labelset_to_set from ...config.spect_params import SpectParamsConfig from .spect import spectrogram - logger = logging.getLogger(__name__) @@ -120,7 +118,9 @@ def make_spectrogram_files_from_audio_files( ) if all([arg is None for arg in (audio_dir, audio_files, audio_annot_map)]): - raise ValueError("must specify one of: audio_dir, audio_files, audio_annot_map") + raise ValueError( + "must specify one of: audio_dir, audio_files, audio_annot_map" + ) if audio_dir and audio_files: raise ValueError( @@ -185,7 +185,9 @@ def make_spectrogram_files_from_audio_files( audio_files = files_from_dir(audio_dir, audio_format) if annot_list: - audio_annot_map = map_annotated_to_annot(audio_files, annot_list, annot_format) + audio_annot_map = map_annotated_to_annot( + audio_files, annot_list, annot_format + ) logger.info("creating array files with spectrograms") @@ -233,7 +235,9 @@ def _spect_file(audio_file): spect_params.audio_path_key: str(audio_file), } basename = os.path.basename(audio_file) - npz_fname = os.path.join(os.path.normpath(output_dir), basename + ".spect.npz") + npz_fname = os.path.join( + os.path.normpath(output_dir), basename + ".spect.npz" + ) np.savez(npz_fname, **spect_dict) return npz_fname diff --git a/src/vak/prep/spectrogram_dataset/prep.py b/src/vak/prep/spectrogram_dataset/prep.py index 99609c6ed..323ddbd56 100644 --- a/src/vak/prep/spectrogram_dataset/prep.py +++ b/src/vak/prep/spectrogram_dataset/prep.py @@ -1,18 +1,17 @@ from __future__ import annotations -from datetime import datetime import logging import pathlib +from datetime import datetime import attrs import crowsetta import pandas as pd -from . 
import audio_helper, spect_helper -from ...config.spect_params import SpectParamsConfig from ...common import annotation from ...common.converters import expanded_user_path, labelset_to_set - +from ...config.spect_params import SpectParamsConfig +from . import audio_helper, spect_helper logger = logging.getLogger(__name__) @@ -110,7 +109,9 @@ def prep_spectrogram_dataset( if spect_output_dir: spect_output_dir = expanded_user_path(spect_output_dir) if not spect_output_dir.is_dir(): - raise NotADirectoryError(f"spect_output_dir not found: {spect_output_dir}") + raise NotADirectoryError( + f"spect_output_dir not found: {spect_output_dir}" + ) else: spect_output_dir = data_dir @@ -125,7 +126,10 @@ def prep_spectrogram_dataset( annot_dir=data_dir, annot_format=annot_format ) scribe = crowsetta.Transcriber(format=annot_format) - annot_list = [scribe.from_file(annot_file).to_annot() for annot_file in annot_files] + annot_list = [ + scribe.from_file(annot_file).to_annot() + for annot_file in annot_files + ] else: scribe = crowsetta.Transcriber(format=annot_format) annot_list = scribe.from_file(annot_file).to_annot() @@ -168,7 +172,9 @@ def prep_spectrogram_dataset( "spect_output_dir": spect_output_dir, } - if spect_files: # because we just made them, and put them in spect_output_dir + if ( + spect_files + ): # because we just made them, and put them in spect_output_dir make_dataframe_kwargs["spect_files"] = spect_files logger.info( f"creating dataset from spectrogram files in: {spect_output_dir}", @@ -179,11 +185,18 @@ def prep_spectrogram_dataset( f"creating dataset from spectrogram files in: {data_dir}", ) - if spect_params: # get relevant keys for accessing arrays from array files + if spect_params: # get relevant keys for accessing arrays from array files if isinstance(spect_params, SpectParamsConfig): spect_params = attrs.asdict(spect_params) - for key in ['freqbins_key', 'timebins_key', 'spect_key', 'audio_path_key']: + for key in [ + "freqbins_key", + 
"timebins_key", + "spect_key", + "audio_path_key", + ]: make_dataframe_kwargs[key] = spect_params[key] - dataset_df = spect_helper.make_dataframe_of_spect_files(**make_dataframe_kwargs) + dataset_df = spect_helper.make_dataframe_of_spect_files( + **make_dataframe_kwargs + ) return dataset_df diff --git a/src/vak/prep/spectrogram_dataset/spect.py b/src/vak/prep/spectrogram_dataset/spect.py index 188226b99..d4d84ada0 100644 --- a/src/vak/prep/spectrogram_dataset/spect.py +++ b/src/vak/prep/spectrogram_dataset/spect.py @@ -6,9 +6,8 @@ https://github.com/timsainb/python_spectrograms_and_inversion """ import numpy as np - -from scipy.signal import butter, lfilter from matplotlib.mlab import specgram +from scipy.signal import butter, lfilter def butter_bandpass(lowcut, highcut, fs, order=5): @@ -68,12 +67,14 @@ def spectrogram( noverlap = fft_size - step_size if freq_cutoffs: - dat = butter_bandpass_filter(dat, freq_cutoffs[0], freq_cutoffs[1], samp_freq) + dat = butter_bandpass_filter( + dat, freq_cutoffs[0], freq_cutoffs[1], samp_freq + ) # below only take [:3] from return of specgram because we don't need the image - spect, freqbins, timebins = specgram(dat, fft_size, samp_freq, noverlap=noverlap)[ - :3 - ] + spect, freqbins, timebins = specgram( + dat, fft_size, samp_freq, noverlap=noverlap + )[:3] if transform_type: if transform_type == "log_spect": diff --git a/src/vak/prep/spectrogram_dataset/spect_helper.py b/src/vak/prep/spectrogram_dataset/spect_helper.py index a521def7c..125b5a8d8 100644 --- a/src/vak/prep/spectrogram_dataset/spect_helper.py +++ b/src/vak/prep/spectrogram_dataset/spect_helper.py @@ -9,15 +9,14 @@ import pathlib import dask.bag as db -from dask.diagnostics import ProgressBar import numpy as np import pandas as pd +from dask.diagnostics import ProgressBar from ...common import constants, files from ...common.annotation import map_annotated_to_annot from ...common.converters import expanded_user_path, labelset_to_set - logger = 
logging.getLogger(__name__) @@ -120,7 +119,7 @@ def make_dataframe_of_spect_files( f"format '{spect_format}' not recognized." ) - if spect_format == 'mat' and spect_output_dir is None: + if spect_format == "mat" and spect_output_dir is None: raise ValueError( "Must provide ``spect_output_dir`` when ``spect_format`` is '.mat'." "so that array files can be converted to npz format. " @@ -129,7 +128,9 @@ def make_dataframe_of_spect_files( ) if all([arg is None for arg in (spect_dir, spect_files, spect_annot_map)]): - raise ValueError("must specify one of: spect_dir, spect_files, spect_annot_map") + raise ValueError( + "must specify one of: spect_dir, spect_files, spect_annot_map" + ) if spect_dir and spect_files: raise ValueError( @@ -161,7 +162,9 @@ def make_dataframe_of_spect_files( "a spect_annot_map was provided, but no annot_format was specified" ) - if annot_format is not None and (annot_list is None and spect_annot_map is None): + if annot_format is not None and ( + annot_list is None and spect_annot_map is None + ): raise ValueError( "an annot_format was specified but no annot_list or spect_annot_map was provided" ) @@ -172,23 +175,35 @@ def make_dataframe_of_spect_files( if spect_output_dir: spect_output_dir = expanded_user_path(spect_output_dir) if not spect_output_dir.is_dir(): - raise NotADirectoryError(f"spect_output_dir not found: {spect_output_dir}") + raise NotADirectoryError( + f"spect_output_dir not found: {spect_output_dir}" + ) # ---- get a list of spectrogram files + associated annotation files ----------------------------------------------- if spect_dir: # then get spect_files from that dir # note we already validated format above - spect_files = sorted(pathlib.Path(spect_dir).glob(f"**/*{spect_format}")) + spect_files = sorted( + pathlib.Path(spect_dir).glob(f"**/*{spect_format}") + ) if spect_files: # (or if we just got them from spect_dir) if annot_list: - spect_annot_map = map_annotated_to_annot(spect_files, annot_list, annot_format, 
annotated_ext=spect_ext) + spect_annot_map = map_annotated_to_annot( + spect_files, annot_list, annot_format, annotated_ext=spect_ext + ) else: # no annotation, so map spectrogram files to None - spect_annot_map = dict((spect_path, None) for spect_path in spect_files) + spect_annot_map = dict( + (spect_path, None) for spect_path in spect_files + ) # use labelset if supplied, to filter - if labelset: # then assume user wants to filter out files where annotation has labels not in labelset - for spect_path, annot in list(spect_annot_map.items()): # `list` so we can pop from dict without RuntimeError + if ( + labelset + ): # then assume user wants to filter out files where annotation has labels not in labelset + for spect_path, annot in list( + spect_annot_map.items() + ): # `list` so we can pop from dict without RuntimeError annot_labelset = set(annot.seq.labels) # below, set(labels_mapping) is a set of that dict's keys if not annot_labelset.issubset(set(labelset)): @@ -233,7 +248,7 @@ def _to_record(spect_annot_tuple): spect_dur = spect_dict[spect_key].shape[-1] * timebin_dur if audio_path_key in spect_dict: audio_path = spect_dict[audio_path_key] - if type(audio_path) == np.ndarray: + if isinstance(audio_path, np.ndarray): # (because everything stored in .npz has to be in an ndarray) audio_path = audio_path.tolist() else: @@ -242,14 +257,16 @@ def _to_record(spect_annot_tuple): # (or an error) audio_path = files.spect.find_audio_fname(spect_path) - if spect_format == 'mat': + if spect_format == "mat": # convert to .npz and save in spect_output_dir spect_dict_npz = { - 's': spect_dict[spect_key], - 't': spect_dict[timebins_key], - 'f': spect_dict[freqbins_key] + "s": spect_dict[spect_key], + "t": spect_dict[timebins_key], + "f": spect_dict[freqbins_key], } - spect_path = spect_output_dir / (pathlib.Path(spect_path).stem + ".npz") + spect_path = spect_output_dir / ( + pathlib.Path(spect_path).stem + ".npz" + ) np.savez(spect_path, **spect_dict_npz) if annot is not 
None: @@ -268,12 +285,15 @@ def abspath(a_path): abspath(audio_path), abspath(spect_path), abspath(annot_path), - annot_format if annot_format else constants.NO_ANNOTATION_FORMAT, + annot_format + if annot_format + else constants.NO_ANNOTATION_FORMAT, spect_dur, timebin_dur, ] ) return record + spect_path_annot_tuples = db.from_sequence(spect_annot_map.items()) logger.info( "creating pandas.DataFrame representing dataset from spectrogram files", diff --git a/src/vak/prep/split/__init__.py b/src/vak/prep/split/__init__.py index aeea84a09..1f878c382 100644 --- a/src/vak/prep/split/__init__.py +++ b/src/vak/prep/split/__init__.py @@ -1,3 +1,9 @@ from . import algorithms - from .split import frame_classification_dataframe, unit_dataframe + + +__all__ = [ + "algorithms", + "frame_classification_dataframe", + "unit_dataframe", +] diff --git a/src/vak/prep/split/algorithms/__init__.py b/src/vak/prep/split/algorithms/__init__.py index 522d12391..bdd3b7591 100644 --- a/src/vak/prep/split/algorithms/__init__.py +++ b/src/vak/prep/split/algorithms/__init__.py @@ -1 +1,6 @@ from .bruteforce import brute_force + + +__all__ = [ + "brute_force" +] diff --git a/src/vak/prep/split/algorithms/bruteforce.py b/src/vak/prep/split/algorithms/bruteforce.py index 425afcd69..f25190351 100644 --- a/src/vak/prep/split/algorithms/bruteforce.py +++ b/src/vak/prep/split/algorithms/bruteforce.py @@ -1,4 +1,5 @@ from __future__ import annotations + import logging import random from typing import Union @@ -25,35 +26,37 @@ def validate_labels(labels: list[np.array], labelset: set) -> None: if uniq_labels < labelset: missing = labelset - uniq_labels raise ValueError( - f'Unable to split using this labelset: {labelset}. ' - f'The following classes of label do not appear in the list of label arrays: {missing}\n' - 'To fix, either remove those classes from the labelset, ' - 'or add vocalizations to the dataset containing the missing labels.' + f"Unable to split using this labelset: {labelset}. 
" + f"The following classes of label do not appear in the list of label arrays: {missing}\n" + "To fix, either remove those classes from the labelset, " + "or add vocalizations to the dataset containing the missing labels." ) elif uniq_labels > labelset: extra = uniq_labels - labelset raise ValueError( - f'Unable to split using this labelset: {labelset}. ' - f'The following classes of label that are not in labelset are found ' - f'in the list of label arrays: {extra}\n' - 'To fix, either add these classes to the labelset, ' - 'or remove the vocalizations from the dataset that contain these labels.' + f"Unable to split using this labelset: {labelset}. " + f"The following classes of label that are not in labelset are found " + f"in the list of label arrays: {extra}\n" + "To fix, either add these classes to the labelset, " + "or remove the vocalizations from the dataset that contain these labels." ) elif uniq_labels & labelset == set(): raise ValueError( - f'Unable to split using this labelset: {labelset}. ' - f'None of the label classes are found in the set of ' - f'unique labels from the list of label arrays: {uniq_labels}.' + f"Unable to split using this labelset: {labelset}. " + f"None of the label classes are found in the set of " + f"unique labels from the list of label arrays: {uniq_labels}." ) -def brute_force(durs: list[float], - labels: list[np.ndarray], - labelset: set, - train_dur: Union[int, float], - val_dur: Union[int, float], - test_dur: Union[int, float], - max_iter: int = 5000) -> (list[int], list[int], list[int]): +def brute_force( + durs: list[float], + labels: list[np.ndarray], + labelset: set, + train_dur: Union[int, float], + val_dur: Union[int, float], + test_dur: Union[int, float], + max_iter: int = 5000, +) -> (list[int], list[int], list[int]): """Generate indices that split a dataset into separate training, validation, and test subsets. 
@@ -143,20 +146,31 @@ def brute_force(durs: list[float], # when making `split_inds`, "initialize" the dict with all split names, by using target_split_durs # so we don't get an error when indexing into dict in return statement below - split_inds = {split_name: [] for split_name in target_split_durs.keys()} - total_split_durs = {split_name: 0 for split_name in target_split_durs.keys()} - split_labelsets = {split_name: set() for split_name in target_split_durs.keys()} + split_inds = { + split_name: [] for split_name in target_split_durs.keys() + } + total_split_durs = { + split_name: 0 for split_name in target_split_durs.keys() + } + split_labelsets = { + split_name: set() for split_name in target_split_durs.keys() + } # list of split 'choices' we use when randomly adding indices to splits choice = [] for split_name in target_split_durs.keys(): - if target_split_durs[split_name] > 0 or target_split_durs[split_name] == -1: + if ( + target_split_durs[split_name] > 0 + or target_split_durs[split_name] == -1 + ): choice.append(split_name) # ---- make sure each split has at least one instance of each label -------------------------------------------- for label_from_labelset in sorted(labelset): label_inds = [ - ind for ind in durs_labels_inds if label_from_labelset in labels[ind] + ind + for ind in durs_labels_inds + if label_from_labelset in labels[ind] ] random.shuffle(label_inds) @@ -169,9 +183,9 @@ def brute_force(durs: list[float], ind = label_inds.pop() split_inds[split_name].append(ind) total_split_durs[split_name] += durs[ind] - split_labelsets[split_name] = split_labelsets[split_name].union( - set(labels[ind]) - ) + split_labelsets[split_name] = split_labelsets[ + split_name + ].union(set(labels[ind])) durs_labels_inds.remove(ind) except IndexError: if len(label_inds) == 0: @@ -188,7 +202,8 @@ def brute_force(durs: list[float], for split_name in target_split_durs.keys(): if ( target_split_durs[split_name] > 0 - and total_split_durs[split_name] >= 
target_split_durs[split_name] + and total_split_durs[split_name] + >= target_split_durs[split_name] ): choice.remove(split_name) @@ -224,7 +239,8 @@ def brute_force(durs: list[float], total_split_durs[split_name] += durs[ind] if ( target_split_durs[split_name] > 0 - and total_split_durs[split_name] >= target_split_durs[split_name] + and total_split_durs[split_name] + >= target_split_durs[split_name] ): choice.remove(split_name) elif target_split_durs[split_name] == -1: @@ -239,7 +255,10 @@ def brute_force(durs: list[float], if len(choice) < 1: # list is empty, we popped off all the choices for split_name in target_split_durs.keys(): if target_split_durs[split_name] > 0: - if total_split_durs[split_name] < target_split_durs[split_name]: + if ( + total_split_durs[split_name] + < target_split_durs[split_name] + ): raise ValueError( "Loop to find splits completed, " f"but total duration of '{split_name}' split, " @@ -264,7 +283,9 @@ def brute_force(durs: list[float], or target_split_durs[split_name] == -1 ): split_labels = [ - label for ind in split_inds[split_name] for label in labels[ind] + label + for ind in split_inds[split_name] + for label in labels[ind] ] split_labelset = set(split_labels) if split_labelset != set(labelset): @@ -285,7 +306,8 @@ def brute_force(durs: list[float], continue split_inds = { - split_name: (inds if inds else None) for split_name, inds in split_inds.items() + split_name: (inds if inds else None) + for split_name, inds in split_inds.items() } return split_inds["train"], split_inds["val"], split_inds["test"] diff --git a/src/vak/prep/split/algorithms/validate.py b/src/vak/prep/split/algorithms/validate.py index 4303bff38..99befd6c6 100644 --- a/src/vak/prep/split/algorithms/validate.py +++ b/src/vak/prep/split/algorithms/validate.py @@ -51,7 +51,11 @@ def validate_split_durations(train_dur, val_dur, test_dur, dataset_dur): ) if not all( - [dur >= 0 or dur == -1 for dur in split_durs.values() if dur is not None] + [ + dur >= 0 or dur == -1 
+ for dur in split_durs.values() + if dur is not None + ] ): raise ValueError( "all durations for split must be real non-negative number or " @@ -69,7 +73,9 @@ def validate_split_durations(train_dur, val_dur, test_dur, dataset_dur): split_durs[split_name] = 0 if -1 in split_durs.values(): - total_other_splits_dur = sum([dur for dur in split_durs.values() if dur != -1]) + total_other_splits_dur = sum( + [dur for dur in split_durs.values() if dur != -1] + ) if total_other_splits_dur > dataset_dur: raise ValueError( diff --git a/src/vak/prep/split/split.py b/src/vak/prep/split/split.py index 4a2407c42..a504406b1 100644 --- a/src/vak/prep/split/split.py +++ b/src/vak/prep/split/split.py @@ -8,10 +8,9 @@ import numpy as np import pandas as pd +from ...common.labels import from_df as labels_from_df from .algorithms import brute_force from .algorithms.validate import validate_split_durations -from ...common.labels import from_df as labels_from_df - logger = logging.getLogger(__name__) @@ -95,8 +94,12 @@ def train_test_dur_split_inds( def frame_classification_dataframe( - dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path, labelset: set, - train_dur: float | None = None, test_dur: float | None = None, val_dur: float | None = None + dataset_df: pd.DataFrame, + dataset_path: str | pathlib.Path, + labelset: set, + train_dur: float | None = None, + test_dur: float | None = None, + val_dur: float | None = None, ): """Create datasets splits from a dataframe representing a frame classification dataset. 
@@ -159,8 +162,12 @@ def frame_classification_dataframe( # start off with all elements set to 'None' # so we don't have to change any that are not assigned to one of the subsets to 'None' after - split_col = np.asarray(["None" for _ in range(len(dataset_df))], dtype="object") - split_zip = zip(["train", "val", "test"], [train_inds, val_inds, test_inds]) + split_col = np.asarray( + ["None" for _ in range(len(dataset_df))], dtype="object" + ) + split_zip = zip( + ["train", "val", "test"], [train_inds, val_inds, test_inds] + ) for split_name, split_inds in split_zip: if split_inds is not None: split_col[split_inds] = split_name @@ -172,8 +179,12 @@ def frame_classification_dataframe( def unit_dataframe( - dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path, labelset: set, - train_dur: float | None = None, test_dur: float | None = None, val_dur: float | None = None + dataset_df: pd.DataFrame, + dataset_path: str | pathlib.Path, + labelset: set, + train_dur: float | None = None, + test_dur: float | None = None, + val_dur: float | None = None, ): """Create datasets splits from a dataframe representing a unit dataset. 
@@ -223,9 +234,7 @@ def unit_dataframe( dataset_df = ( dataset_df.copy() ) # don't want this function to have unexpected side effects, so return a copy - labels = [ - np.array([label]) for label in dataset_df.label.values - ] + labels = [np.array([label]) for label in dataset_df.label.values] durs = dataset_df["duration"].values train_inds, val_inds, test_inds = train_test_dur_split_inds( @@ -239,8 +248,12 @@ def unit_dataframe( # start off with all elements set to 'None' # so we don't have to change any that are not assigned to one of the subsets to 'None' after - split_col = np.asarray(["None" for _ in range(len(dataset_df))], dtype="object") - split_zip = zip(["train", "val", "test"], [train_inds, val_inds, test_inds]) + split_col = np.asarray( + ["None" for _ in range(len(dataset_df))], dtype="object" + ) + split_zip = zip( + ["train", "val", "test"], [train_inds, val_inds, test_inds] + ) for split_name, split_inds in split_zip: if split_inds is not None: split_col[split_inds] = split_name @@ -248,4 +261,4 @@ def unit_dataframe( # add split column to dataframe dataset_df["split"] = split_col - return dataset_df \ No newline at end of file + return dataset_df diff --git a/src/vak/prep/unit_dataset/__init__.py b/src/vak/prep/unit_dataset/__init__.py index d2e934e28..3f1a73469 100644 --- a/src/vak/prep/unit_dataset/__init__.py +++ b/src/vak/prep/unit_dataset/__init__.py @@ -1,2 +1,7 @@ from . 
import unit_dataset from .unit_dataset import prep_unit_dataset + +__all__ = [ + "prep_unit_dataset", + "unit_dataset" +] diff --git a/src/vak/prep/unit_dataset/unit_dataset.py b/src/vak/prep/unit_dataset/unit_dataset.py index 31464d1a8..457c6c12f 100644 --- a/src/vak/prep/unit_dataset/unit_dataset.py +++ b/src/vak/prep/unit_dataset/unit_dataset.py @@ -10,17 +10,16 @@ import crowsetta import dask import dask.delayed -from dask.diagnostics import ProgressBar import numpy as np import numpy.typing as npt import pandas as pd +from dask.diagnostics import ProgressBar from ...common import annotation, constants from ...common.converters import expanded_user_path, labelset_to_set from ..spectrogram_dataset.audio_helper import files_from_dir from ..spectrogram_dataset.spect import spectrogram - logger = logging.getLogger(__name__) @@ -36,6 +35,7 @@ class Segment: The dataset including metadata is saved as a csv file where these attributes become the columns. """ + data: npt.NDArray samplerate: int onset_s: float @@ -49,7 +49,10 @@ class Segment: @dask.delayed def get_segment_list( - audio_path: str, annot: crowsetta.Annotation, audio_format: str, context_s: float = 0.005 + audio_path: str, + annot: crowsetta.Annotation, + audio_format: str, + context_s: float = 0.005, ) -> list[Segment]: """Get a list of :class:`Segment` instances, given the path to an audio file and an annotation that indicates @@ -77,26 +80,40 @@ def get_segment_list( segments : list A :class:`list` of :class:`Segment` instances. """ - data, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](audio_path) - sample_dur = 1. 
/ samplerate + data, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format]( + audio_path + ) + sample_dur = 1.0 / samplerate segments = [] - for onset_s, offset_s, label in zip(annot.seq.onsets_s, annot.seq.offsets_s, annot.seq.labels): + for onset_s, offset_s, label in zip( + annot.seq.onsets_s, annot.seq.offsets_s, annot.seq.labels + ): onset_s -= context_s offset_s += context_s onset_ind = int(np.floor(onset_s * samplerate)) offset_ind = int(np.ceil(offset_s * samplerate)) - segment_data = data[onset_ind : offset_ind + 1] + segment_data = data[onset_ind: offset_ind + 1] segment_dur = segment_data.shape[-1] * sample_dur segment = Segment( - segment_data, samplerate, onset_s, offset_s, label, sample_dur, segment_dur, audio_path, annot.annot_path + segment_data, + samplerate, + onset_s, + offset_s, + label, + sample_dur, + segment_dur, + audio_path, + annot.annot_path, ) segments.append(segment) return segments -def spectrogram_from_segment(segment: Segment, spect_params: dict) -> npt.NDArray: +def spectrogram_from_segment( + segment: Segment, spect_params: dict +) -> npt.NDArray: """Compute a spectrogram given a :class:`Segment` instance. Parameters @@ -127,12 +144,15 @@ class SpectToSave: Used by :func:`save_spect`. """ + spect: npt.NDArray ind: int audio_path: str -def save_spect(spect_to_save: SpectToSave, output_dir: str | pathlib.Path) -> str: +def save_spect( + spect_to_save: SpectToSave, output_dir: str | pathlib.Path +) -> str: """Save a spectrogram array to an npy file. 
The filename is build from the attributes of ``spect_to_save``, @@ -148,8 +168,13 @@ def save_spect(spect_to_save: SpectToSave, output_dir: str | pathlib.Path) -> st npy_path : str Path to npy file containing spectrogram inside ``output_dir`` """ - basename = os.path.basename(spect_to_save.audio_path) + f"-segment-{spect_to_save.ind}" - npy_path = os.path.join(os.path.normpath(output_dir), basename + ".spect.npy") + basename = ( + os.path.basename(spect_to_save.audio_path) + + f"-segment-{spect_to_save.ind}" + ) + npy_path = os.path.join( + os.path.normpath(output_dir), basename + ".spect.npy" + ) np.save(npy_path, spect_to_save.spect) return npy_path @@ -164,7 +189,9 @@ def abspath(a_path): # ---- make spectrograms + records for dataframe ----------------------------------------------------------------------- @dask.delayed -def make_spect_return_record(segment: Segment, ind: int, spect_params: dict, output_dir: pathlib.Path) -> tuple: +def make_spect_return_record( + segment: Segment, ind: int, spect_params: dict, output_dir: pathlib.Path +) -> tuple: """Helper function that enables parallelized creation of "records", i.e. rows for dataframe, from . Accepts a two-element tuple containing (1) a dictionary that represents a spectrogram @@ -230,14 +257,14 @@ def pad_spectrogram(record: tuple, pad_length: float) -> None: def prep_unit_dataset( - audio_format: str, - output_dir: str, - spect_params: dict, - data_dir: list | None = None, - annot_format: str | None = None, - annot_file: str | pathlib.Path | None = None, - labelset: set | None = None, - context_s: float = 0.005, + audio_format: str, + output_dir: str, + spect_params: dict, + data_dir: list | None = None, + annot_format: str | None = None, + annot_file: str | pathlib.Path | None = None, + labelset: set | None = None, + context_s: float = 0.005, ) -> pd.DataFrame: """Prepare a dataset of units from sequences, e.g., all syllables segmented out of a dataset of birdsong. 
@@ -284,7 +311,10 @@ def prep_unit_dataset( annot_dir=data_dir, annot_format=annot_format ) scribe = crowsetta.Transcriber(format=annot_format) - annot_list = [scribe.from_file(annot_file).to_annot() for annot_file in annot_files] + annot_list = [ + scribe.from_file(annot_file).to_annot() + for annot_file in annot_files + ] else: scribe = crowsetta.Transcriber(format=annot_format) annot_list = scribe.from_file(annot_file).to_annot() @@ -296,13 +326,19 @@ def prep_unit_dataset( annot_list = None if annot_list: - audio_annot_map = annotation.map_annotated_to_annot(audio_files, annot_list, annot_format) + audio_annot_map = annotation.map_annotated_to_annot( + audio_files, annot_list, annot_format + ) else: # no annotation, so map spectrogram files to None - audio_annot_map = dict((audio_path, None) for audio_path in audio_files) + audio_annot_map = dict( + (audio_path, None) for audio_path in audio_files + ) # use labelset, if supplied, with annotations, if any, to filter; - if labelset and annot_list: # then remove annotations with labels not in labelset + if ( + labelset and annot_list + ): # then remove annotations with labels not in labelset for audio_file, annot in list(audio_annot_map.items()): # loop in a verbose way (i.e. 
not a comprehension) # so we can give user warning when we skip files @@ -319,7 +355,9 @@ def prep_unit_dataset( segments = [] for audio_path, annot in audio_annot_map.items(): - segment_list = dask.delayed(get_segment_list)(audio_path, annot, audio_format, context_s) + segment_list = dask.delayed(get_segment_list)( + audio_path, annot, audio_format, context_s + ) segments.append(segment_list) logger.info( @@ -327,7 +365,9 @@ def prep_unit_dataset( ) with ProgressBar(): segments: list[list[Segment]] = dask.compute(*segments) - segments: list[Segment] = [segment for segment_list in segments for segment in segment_list] + segments: list[Segment] = [ + segment for segment_list in segments for segment in segment_list + ] # ---- make and save all spectrograms *before* padding # This is a design choice to avoid keeping all the spectrograms in memory @@ -336,10 +376,14 @@ def prep_unit_dataset( # Might be worth looking at how often typical dataset sizes in memory and whether this is really necessary. 
records_n_timebins_tuples = [] for ind, segment in enumerate(segments): - records_n_timebins_tuple = make_spect_return_record(segment, ind, spect_params, output_dir) + records_n_timebins_tuple = make_spect_return_record( + segment, ind, spect_params, output_dir + ) records_n_timebins_tuples.append(records_n_timebins_tuple) with ProgressBar(): - records_n_timebins_tuples: list[tuple[tuple, int]] = dask.compute(*records_n_timebins_tuples) + records_n_timebins_tuples: list[tuple[tuple, int]] = dask.compute( + *records_n_timebins_tuples + ) records, n_timebins_list = [], [] for records_n_timebins_tuple in records_n_timebins_tuples: @@ -351,14 +395,14 @@ def prep_unit_dataset( padded = [] for record in records: - padded.append( - pad_spectrogram(record, pad_length) - ) + padded.append(pad_spectrogram(record, pad_length)) with ProgressBar(): - shapes:list[tuple[int, int]] = dask.compute(*padded) + shapes: list[tuple[int, int]] = dask.compute(*padded) shape = set(shapes) - assert len(shape) == 1, f"Did not find a single unique shape for all spectrograms. Instead found: {shape}" + assert ( + len(shape) == 1 + ), f"Did not find a single unique shape for all spectrograms. Instead found: {shape}" shape = shape.pop() unit_df = pd.DataFrame.from_records(records, columns=DF_COLUMNS) diff --git a/src/vak/train/__init__.py b/src/vak/train/__init__.py index ddefa442a..3e99d4438 100644 --- a/src/vak/train/__init__.py +++ b/src/vak/train/__init__.py @@ -1,6 +1,8 @@ -from .train import * - +from . 
import frame_classification, parametric_umap +from .train import train __all__ = [ + "frame_classification", + "parametric_umap", "train", ] diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 645236a3a..925b27d0e 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -1,31 +1,21 @@ """Function that trains models in the frame classification family.""" from __future__ import annotations +import datetime import json import logging import pathlib import shutil -import datetime import joblib - import pandas as pd import torch.utils.data -from .. import ( - datasets, - models, - transforms, -) +from .. import datasets, models, transforms from ..common import validators -from ..datasets.frame_classification import ( - WindowDataset, - FramesDataset -) from ..common.device import get_default as get_default_device -from ..common.paths import generate_results_dir_name_as_path from ..common.trainer import get_default_trainer - +from ..datasets.frame_classification import FramesDataset, WindowDataset logger = logging.getLogger(__name__) @@ -55,7 +45,7 @@ def train_frame_classification_model( ckpt_step: int | None = None, patience: int | None = None, device: str | None = None, - split: str = 'train', + split: str = "train", ) -> None: """Train a model from the frame classification family and save results. @@ -158,8 +148,8 @@ def train_frame_classification_model( training set to use when training models for a learning curve. 
""" for path, path_name in zip( - (checkpoint_path, spect_scaler_path), - ('checkpoint_path', 'spect_scaler_path'), + (checkpoint_path, spect_scaler_path), + ("checkpoint_path", "spect_scaler_path"), ): if path is not None: if not validators.is_a_file(path): @@ -176,7 +166,9 @@ def train_frame_classification_model( logger.info( f"Loading dataset from path: {dataset_path}", ) - metadata = datasets.frame_classification.Metadata.from_dataset_path(dataset_path) + metadata = datasets.frame_classification.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) # ---------------- pre-conditions ---------------------------------------------------------------------------------- @@ -210,9 +202,7 @@ def train_frame_classification_model( ) labelmap_path = dataset_path / "labelmap.json" - logger.info( - f"loading labelmap from path: {labelmap_path}" - ) + logger.info(f"loading labelmap from path: {labelmap_path}") with labelmap_path.open("r") as f: labelmap = json.load(f) # copy to new results_path @@ -220,34 +210,37 @@ def train_frame_classification_model( json.dump(labelmap, f) if spect_scaler_path is not None and normalize_spectrograms: - logger.info( - f"loading spect scaler from path: {spect_scaler_path}" - ) + logger.info(f"loading spect scaler from path: {spect_scaler_path}") spect_standardizer = joblib.load(spect_scaler_path) shutil.copy(spect_scaler_path, results_path) # get transforms just before creating datasets with them elif normalize_spectrograms and spect_scaler_path is None: logger.info( - f"no spect_scaler_path provided, not loading", + "no spect_scaler_path provided, not loading", ) logger.info("will normalize spectrograms") spect_standardizer = transforms.StandardizeSpect.fit_dataset_path( - dataset_path, split=split, + dataset_path, + split=split, + ) + joblib.dump( + spect_standardizer, results_path.joinpath("StandardizeSpect") ) - 
joblib.dump(spect_standardizer, results_path.joinpath("StandardizeSpect")) elif spect_scaler_path is not None and not normalize_spectrograms: - raise ValueError('spect_scaler_path provided but normalize_spectrograms was False, these options conflict') + raise ValueError( + "spect_scaler_path provided but normalize_spectrograms was False, these options conflict" + ) else: - #not normalize_spectrograms and spect_scaler_path is None: + # not normalize_spectrograms and spect_scaler_path is None: logger.info( "normalize_spectrograms is False and no spect_scaler_path was provided, " "will not standardize spectrograms", - ) + ) spect_standardizer = None if train_transform_params is None: train_transform_params = {} - train_transform_params.update({'spect_standardizer': spect_standardizer}) + train_transform_params.update({"spect_standardizer": spect_standardizer}) transform, target_transform = transforms.defaults.get_default_transform( model_name, "train", transform_kwargs=train_transform_params ) @@ -284,11 +277,9 @@ def train_frame_classification_model( if val_transform_params is None: val_transform_params = {} - val_transform_params.update({'spect_standardizer': spect_standardizer}) + val_transform_params.update({"spect_standardizer": spect_standardizer}) item_transform = transforms.defaults.get_default_transform( - model_name, - "eval", - val_transform_params + model_name, "eval", val_transform_params ) if val_dataset_params is None: val_dataset_params = {} @@ -335,9 +326,9 @@ def train_frame_classification_model( logger.info(f"training {model_name}") max_steps = num_epochs * len(train_loader) default_callback_kwargs = { - 'ckpt_root': ckpt_root, - 'ckpt_step': ckpt_step, - 'patience': patience, + "ckpt_root": ckpt_root, + "ckpt_step": ckpt_step, + "patience": patience, } trainer = get_default_trainer( max_steps=max_steps, @@ -347,19 +338,13 @@ def train_frame_classification_model( device=device, ) train_time_start = datetime.datetime.now() - logger.info( - f"Training 
start time: {train_time_start.isoformat()}" - ) + logger.info(f"Training start time: {train_time_start.isoformat()}") trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, ) train_time_stop = datetime.datetime.now() - logger.info( - f"Training stop time: {train_time_stop.isoformat()}" - ) + logger.info(f"Training stop time: {train_time_stop.isoformat()}") elapsed = train_time_stop - train_time_start - logger.info( - f"Elapsed training time: {elapsed}" - ) + logger.info(f"Elapsed training time: {elapsed}") diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py index 22c2a45ea..b2326eb0e 100644 --- a/src/vak/train/parametric_umap.py +++ b/src/vak/train/parametric_umap.py @@ -1,24 +1,19 @@ """Function that trains models in the Parametric UMAP family.""" from __future__ import annotations +import datetime import logging import pathlib -import datetime import pandas as pd -import torch.utils.data import pytorch_lightning as lightning +import torch.utils.data -from .. import ( - datasets, - models, - transforms, -) +from .. import datasets, models, transforms from ..common import validators -from ..datasets.parametric_umap import ParametricUMAPDataset from ..common.device import get_default as get_default_device from ..common.paths import generate_results_dir_name_as_path - +from ..datasets.parametric_umap import ParametricUMAPDataset logger = logging.getLogger(__name__) @@ -28,50 +23,48 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: return df[df["split"] == split]["duration"].sum() -def get_trainer(max_epochs: int, - ckpt_root: str | pathlib.Path, - ckpt_step: int, - log_save_dir: str | pathlib.Path, - device: str = 'cuda', - ) -> lightning.Trainer: +def get_trainer( + max_epochs: int, + ckpt_root: str | pathlib.Path, + ckpt_step: int, + log_save_dir: str | pathlib.Path, + device: str = "cuda", +) -> lightning.Trainer: """Returns an instance of ``lightning.Trainer`` with a default set of callbacks. 
Used by ``vak.core`` functions.""" - if device == 'cuda': - accelerator = 'gpu' + if device == "cuda": + accelerator = "gpu" else: accelerator = None ckpt_callback = lightning.callbacks.ModelCheckpoint( dirpath=ckpt_root, - filename='checkpoint', + filename="checkpoint", every_n_train_steps=ckpt_step, save_last=True, verbose=True, ) - ckpt_callback.CHECKPOINT_NAME_LAST = 'checkpoint' - ckpt_callback.FILE_EXTENSION = '.pt' + ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" + ckpt_callback.FILE_EXTENSION = ".pt" val_ckpt_callback = lightning.callbacks.ModelCheckpoint( monitor="val_loss", dirpath=ckpt_root, save_top_k=1, - mode='min', - filename='min-val-loss-checkpoint', + mode="min", + filename="min-val-loss-checkpoint", auto_insert_metric_name=False, - verbose=True + verbose=True, ) - val_ckpt_callback.FILE_EXTENSION = '.pt' + val_ckpt_callback.FILE_EXTENSION = ".pt" callbacks = [ ckpt_callback, val_ckpt_callback, ] - - logger = lightning.loggers.TensorBoardLogger( - save_dir=log_save_dir - ) + logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir) trainer = lightning.Trainer( max_epochs=max_epochs, @@ -100,7 +93,7 @@ def train_parametric_umap_model( val_step: int | None = None, ckpt_step: int | None = None, device: str | None = None, - split: str = 'train', + split: str = "train", ) -> None: """Train a model from the parametric UMAP family and save results. @@ -172,8 +165,8 @@ def train_parametric_umap_model( training set to use when training models for a learning curve. 
""" for path, path_name in zip( - (checkpoint_path,), - ('checkpoint_path',), + (checkpoint_path,), + ("checkpoint_path",), ): if path is not None: if not validators.is_a_file(path): @@ -190,7 +183,9 @@ def train_parametric_umap_model( logger.info( f"Loading dataset from path: {dataset_path}", ) - metadata = datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) + metadata = datasets.parametric_umap.Metadata.from_dataset_path( + dataset_path + ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) # ---------------- pre-conditions ---------------------------------------------------------------------------------- @@ -224,10 +219,15 @@ def train_parametric_umap_model( if train_transform_params is None: train_transform_params = {} - if 'padding' not in train_transform_params and model_name == 'ConvEncoderUMAP': + if ( + "padding" not in train_transform_params + and model_name == "ConvEncoderUMAP" + ): padding = models.convencoder_umap.get_default_padding(metadata.shape) - train_transform_params['padding'] = padding - transform = transforms.defaults.get_default_transform(model_name, "train", train_transform_params) + train_transform_params["padding"] = padding + transform = transforms.defaults.get_default_transform( + model_name, "train", train_transform_params + ) if train_dataset_params is None: train_dataset_params = {} @@ -251,10 +251,17 @@ def train_parametric_umap_model( if val_step: if val_transform_params is None: val_transform_params = {} - if 'padding' not in val_transform_params and model_name == 'ConvEncoderUMAP': - padding = models.convencoder_umap.get_default_padding(metadata.shape) - val_transform_params['padding'] = padding - transform = transforms.defaults.get_default_transform(model_name, "eval", val_transform_params) + if ( + "padding" not in val_transform_params + and model_name == "ConvEncoderUMAP" + ): + padding = models.convencoder_umap.get_default_padding( + metadata.shape + ) + 
val_transform_params["padding"] = padding + transform = transforms.defaults.get_default_transform( + model_name, "eval", val_transform_params + ) if val_dataset_params is None: val_dataset_params = {} val_dataset = ParametricUMAPDataset.from_dataset_path( @@ -303,19 +310,13 @@ def train_parametric_umap_model( ckpt_step=ckpt_step, ) train_time_start = datetime.datetime.now() - logger.info( - f"Training start time: {train_time_start.isoformat()}" - ) + logger.info(f"Training start time: {train_time_start.isoformat()}") trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, ) train_time_stop = datetime.datetime.now() - logger.info( - f"Training stop time: {train_time_stop.isoformat()}" - ) + logger.info(f"Training stop time: {train_time_stop.isoformat()}") elapsed = train_time_stop - train_time_start - logger.info( - f"Elapsed training time: {elapsed}" - ) + logger.info(f"Elapsed training time: {elapsed}") diff --git a/src/vak/train/train.py b/src/vak/train/train.py index 69f38efea..c25046827 100644 --- a/src/vak/train/train.py +++ b/src/vak/train/train.py @@ -4,13 +4,10 @@ import logging import pathlib +from .. import models +from ..common import validators from .frame_classification import train_frame_classification_model from .parametric_umap import train_parametric_umap_model -from .. import ( - models, -) -from ..common import validators - logger = logging.getLogger(__name__) @@ -35,7 +32,7 @@ def train( ckpt_step: int | None = None, patience: int | None = None, device: str | None = None, - split: str = 'train', + split: str = "train", ): """Train a model and save results. @@ -84,7 +81,7 @@ def train( Path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. If specified, this checkpoint will be loaded into model. - Used when continuing training. + Used when continuing training. Default is None, in which case a new model is initialized. 
spect_scaler_path : str, pathlib.Path path to a ``SpectScaler`` used to normalize spectrograms, @@ -146,8 +143,8 @@ def train( training set to use when training models for a learning curve. """ for path, path_name in zip( - (checkpoint_path, spect_scaler_path), - ('checkpoint_path', 'spect_scaler_path'), + (checkpoint_path, spect_scaler_path), + ("checkpoint_path", "spect_scaler_path"), ): if path is not None: if not validators.is_a_file(path): @@ -211,6 +208,4 @@ def train( split=split, ) else: - raise ValueError( - f"Model family not recognized: {model_family}" - ) + raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/transforms/__init__.py b/src/vak/transforms/__init__.py index 859f020b9..2f34644c0 100644 --- a/src/vak/transforms/__init__.py +++ b/src/vak/transforms/__init__.py @@ -1,5 +1,2 @@ -from . import ( - defaults, - frame_labels -) -from .transforms import * +from . import defaults, frame_labels # noqa : F401 +from .transforms import * # noqa : F403 diff --git a/src/vak/transforms/defaults/__init__.py b/src/vak/transforms/defaults/__init__.py index f5f585b69..d82644b60 100644 --- a/src/vak/transforms/defaults/__init__.py +++ b/src/vak/transforms/defaults/__init__.py @@ -1,12 +1,4 @@ -from . import ( - frame_classification, - parametric_umap, -) +from . 
import frame_classification, parametric_umap from .get import get_default_transform - -__all__ = [ - "get_default_transform", - "frame_classification", - "parametric_umap" -] +__all__ = ["get_default_transform", "frame_classification", "parametric_umap"] diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index 40f508516..2b5733b18 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -41,7 +41,9 @@ def __init__( vak_transforms.AddChannel(), ] ) - self.source_transform = torchvision.transforms.Compose(source_transform) + self.source_transform = torchvision.transforms.Compose( + source_transform + ) self.annot_transform = vak_transforms.ToLongTensor() def __call__(self, source, annot, spect_path=None): @@ -77,7 +79,9 @@ def __init__( channel_dim=1, ): if spect_standardizer is not None: - if not isinstance(spect_standardizer, vak_transforms.StandardizeSpect): + if not isinstance( + spect_standardizer, vak_transforms.StandardizeSpect + ): raise TypeError( f"invalid type for spect_standardizer: {type(spect_standardizer)}. " "Should be an instance of vak.transforms.StandardizeSpect" @@ -145,7 +149,9 @@ def __init__( channel_dim=1, ): if spect_standardizer is not None: - if not isinstance(spect_standardizer, vak_transforms.StandardizeSpect): + if not isinstance( + spect_standardizer, vak_transforms.StandardizeSpect + ): raise TypeError( f"invalid type for spect_standardizer: {type(spect_standardizer)}. " "Should be an instance of vak.transforms.StandardizeSpect" @@ -191,7 +197,7 @@ def __call__(self, frames, source_path=None): def get_default_frame_classification_transform( - mode: str, transform_kwargs: dict + mode: str, transform_kwargs: dict ) -> tuple[Callable, Callable] | Callable: """Get default transform for frame classification model. 
@@ -220,7 +226,7 @@ def get_default_frame_classification_transform( ------- """ - spect_standardizer = transform_kwargs.get('spect_standardizer', None) + spect_standardizer = transform_kwargs.get("spect_standardizer", None) # regardless of mode, transform always starts with StandardizeSpect, if used if spect_standardizer is not None: if not isinstance(spect_standardizer, vak_transforms.StandardizeSpect): @@ -249,18 +255,22 @@ def get_default_frame_classification_transform( elif mode == "predict": item_transform = PredictItemTransform( spect_standardizer=spect_standardizer, - window_size=transform_kwargs['window_size'], - padval=transform_kwargs.get('padval', 0.0), - return_padding_mask=transform_kwargs.get('return_padding_mask', True), + window_size=transform_kwargs["window_size"], + padval=transform_kwargs.get("padval", 0.0), + return_padding_mask=transform_kwargs.get( + "return_padding_mask", True + ), ) return item_transform elif mode == "eval": item_transform = EvalItemTransform( spect_standardizer=spect_standardizer, - window_size=transform_kwargs['window_size'], - padval=transform_kwargs.get('padval', 0.0), - return_padding_mask=transform_kwargs.get('return_padding_mask', True), + window_size=transform_kwargs["window_size"], + padval=transform_kwargs.get("padval", 0.0), + return_padding_mask=transform_kwargs.get( + "return_padding_mask", True + ), ) return item_transform else: diff --git a/src/vak/transforms/defaults/get.py b/src/vak/transforms/defaults/get.py index d86736e7e..0851d515c 100644 --- a/src/vak/transforms/defaults/get.py +++ b/src/vak/transforms/defaults/get.py @@ -1,11 +1,8 @@ """Helper function that gets default transforms for a model.""" from __future__ import annotations -from . import ( - frame_classification, - parametric_umap, -) from ... import models +from . 
import frame_classification, parametric_umap def get_default_transform( @@ -44,4 +41,6 @@ def get_default_transform( ) elif model_family == "ParametricUMAPModel": - return parametric_umap.get_default_parametric_umap_transform(transform_kwargs) + return parametric_umap.get_default_parametric_umap_transform( + transform_kwargs + ) diff --git a/src/vak/transforms/defaults/parametric_umap.py b/src/vak/transforms/defaults/parametric_umap.py index 345332838..07ededc64 100644 --- a/src/vak/transforms/defaults/parametric_umap.py +++ b/src/vak/transforms/defaults/parametric_umap.py @@ -6,7 +6,9 @@ from .. import transforms as vak_transforms -def get_default_parametric_umap_transform(transform_kwargs) -> torchvision.transforms.Compose: +def get_default_parametric_umap_transform( + transform_kwargs, +) -> torchvision.transforms.Compose: """Get default transform for frame classification model. Parameters @@ -21,8 +23,8 @@ def get_default_parametric_umap_transform(transform_kwargs) -> torchvision.trans vak_transforms.ToFloatTensor(), vak_transforms.AddChannel(), ] - if 'padding' in transform_kwargs: + if "padding" in transform_kwargs: transforms.append( - torchvision.transforms.Pad(transform_kwargs['padding']) + torchvision.transforms.Pad(transform_kwargs["padding"]) ) return torchvision.transforms.Compose(transforms) diff --git a/src/vak/transforms/frame_labels/__init__.py b/src/vak/transforms/frame_labels/__init__.py index 69911c255..e9f503804 100644 --- a/src/vak/transforms/frame_labels/__init__.py +++ b/src/vak/transforms/frame_labels/__init__.py @@ -1,7 +1,2 @@ -from .functional import * -from .transforms import ( - FromSegments, - PostProcess, - ToLabels, - ToSegments, -) +from .functional import * # noqa : F401 +from .transforms import FromSegments, PostProcess, ToLabels, ToSegments # noqa : F401 diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index afaa2fd1b..ba37511a8 100644 --- 
a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -1,7 +1,7 @@ """Functional forms of transformations related to frame labels, i.e., vectors where each element represents -a label for a frame, either a single sample in audio +a label for a frame, either a single sample in audio or a single time bin from a spectrogram. This module is structured as followed: @@ -25,24 +25,25 @@ from ...common.timebins import timebin_dur_from_vec from ...common.validators import column_or_1d, row_or_1d - __all__ = [ # keep alphabetized - 'from_segments', - 'postprocess', - 'remove_short_segments', - 'take_majority_vote', - 'to_inds_list', - 'to_labels', - 'to_segments', + "from_segments", + "postprocess", + "remove_short_segments", + "take_majority_vote", + "to_inds_list", + "to_labels", + "to_segments", ] -def from_segments(labels_int: np.ndarray, - onsets_s: np.ndarray, - offsets_s: np.ndarray, - time_bins: np.ndarray, - unlabeled_label: int = 0) -> np.ndarray: +def from_segments( + labels_int: np.ndarray, + onsets_s: np.ndarray, + offsets_s: np.ndarray, + time_bins: np.ndarray, + unlabeled_label: int = 0, +) -> np.ndarray: """Make a vector of labels for a vector of frames, given labeled segments in the form of onset times, offset times, and segment labels. 
@@ -68,23 +69,24 @@ def from_segments(labels_int: np.ndarray, same length as time_bins, with each element a label for each time bin """ if ( - ( - type(labels_int) == list - and not all([type(lbl) == int for lbl in labels_int]) - ) or - ( - type(labels_int) == np.ndarray - and labels_int.dtype not in [np.int8, np.int16, np.int32, np.int64] - ) + isinstance(labels_int, list) + and not all([isinstance(lbl, int) for lbl in labels_int]) + ) or ( + isinstance(labels_int, np.ndarray) + and labels_int.dtype not in [np.int8, np.int16, np.int32, np.int64] ): - raise TypeError("labels_int must be a list or numpy.ndarray of integers") + raise TypeError( + "labels_int must be a list or numpy.ndarray of integers" + ) label_vec = np.ones((time_bins.shape[-1],), dtype="int8") * unlabeled_label onset_inds = [np.argmin(np.abs(time_bins - onset)) for onset in onsets_s] - offset_inds = [np.argmin(np.abs(time_bins - offset)) for offset in offsets_s] + offset_inds = [ + np.argmin(np.abs(time_bins - offset)) for offset in offsets_s + ] for label, onset, offset in zip(labels_int, onset_inds, offset_inds): # offset_inds[ind]+1 because offset time bin is still "part of" syllable - label_vec[onset:offset + 1] = label + label_vec[onset: offset + 1] = label # noqa: E203 return label_vec @@ -138,10 +140,10 @@ def to_labels(frame_labels: np.ndarray, labelmap: dict) -> str: def to_segments( - frame_labels: np.ndarray, - labelmap: dict, - frame_times: np.ndarray, - n_decimals_trunc: int = 5 + frame_labels: np.ndarray, + labelmap: dict, + frame_times: np.ndarray, + n_decimals_trunc: int = 5, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Convert a vector of frame labels into segments in the form of onset indices, @@ -192,15 +194,22 @@ def to_segments( # handle the case when all time bins are predicted to be unlabeled # see https://github.com/NickleDave/vak/issues/383 uniq_frame_labels = np.unique(frame_labels) - if len(uniq_frame_labels) == 1 and uniq_frame_labels[0] == labelmap["unlabeled"]: 
+ if ( + len(uniq_frame_labels) == 1 + and uniq_frame_labels[0] == labelmap["unlabeled"] + ): return None, None, None # used to find onsets/offsets below; compute here so if we fail we do so early timebin_dur = timebin_dur_from_vec(frame_times, n_decimals_trunc) - offset_inds = np.nonzero(np.diff(frame_labels, axis=0))[0] # [0] because nonzero return tuple + offset_inds = np.nonzero(np.diff(frame_labels, axis=0))[ + 0 + ] # [0] because nonzero return tuple onset_inds = offset_inds + 1 - offset_inds = np.concatenate((offset_inds, np.asarray([frame_labels.shape[0] - 1]))) + offset_inds = np.concatenate( + (offset_inds, np.asarray([frame_labels.shape[0] - 1])) + ) onset_inds = np.concatenate((np.asarray([0]), onset_inds)) labels = frame_labels[onset_inds] @@ -242,7 +251,9 @@ def to_segments( return labels, onsets_s, offsets_s -def to_inds_list(frame_labels: np.ndarray, unlabeled_label: int = 0) -> list[np.ndarray]: +def to_inds_list( + frame_labels: np.ndarray, unlabeled_label: int = 0 +) -> list[np.ndarray]: """Given a vector of frame labels, returns a list of indexing vectors, one for each labeled segment in the vector. @@ -270,11 +281,11 @@ def to_inds_list(frame_labels: np.ndarray, unlabeled_label: int = 0) -> list[np. def remove_short_segments( - frame_labels: np.ndarray, - segment_inds_list: list[np.ndarray], - timebin_dur: float, - min_segment_dur: float | int, - unlabeled_label: int = 0 + frame_labels: np.ndarray, + segment_inds_list: list[np.ndarray], + timebin_dur: float, + min_segment_dur: float | int, + unlabeled_label: int = 0, ) -> tuple[np.ndarray, list[np.ndarray]]: """Remove segments from vector of frame labels that are shorter than a specified duration. 
@@ -326,8 +337,9 @@ def remove_short_segments( return frame_labels, new_segment_inds_list -def take_majority_vote(frame_labels: np.ndarray, - segment_inds_list: list[np.ndarray]) -> np.ndarray: +def take_majority_vote( + frame_labels: np.ndarray, segment_inds_list: list[np.ndarray] +) -> np.ndarray: """Transform segments containing multiple labels into segments with a single label by taking a "majority vote", i.e. assign all frames in the segment the most frequently @@ -358,11 +370,11 @@ def take_majority_vote(frame_labels: np.ndarray, def postprocess( - frame_labels: np.ndarray, - timebin_dur: float, - unlabeled_label: int = 0, - min_segment_dur: float | None = None, - majority_vote: bool = False, + frame_labels: np.ndarray, + timebin_dur: float, + unlabeled_label: int = 0, + min_segment_dur: float | None = None, + majority_vote: bool = False, ) -> np.ndarray: """Apply post-processing transformations to a vector of frame labels. diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index 5ce2bf30a..2734b7da0 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -38,14 +38,17 @@ class FromSegments: Label assigned to time bins that do not have labels associated with them. Default is 0. """ + def __init__(self, unlabeled_label: int = 0): self.unlabeled_label = unlabeled_label - def __call__(self, - labels_int: np.ndarray, - onsets_s: np.ndarray, - offsets_s: np.ndarray, - time_bins: np.ndarray) -> np.ndarray: + def __call__( + self, + labels_int: np.ndarray, + onsets_s: np.ndarray, + offsets_s: np.ndarray, + time_bins: np.ndarray, + ) -> np.ndarray: """Make a vector of frame labels, given labeled segments in the form of onset times, offset times, and segment labels. 
@@ -67,8 +70,13 @@ def __call__(self, frame_labels : numpy.ndarray same length as time_bins, with each element a label for each time bin """ - return F.from_segments(labels_int, onsets_s, offsets_s, time_bins, - unlabeled_label=self.unlabeled_label) + return F.from_segments( + labels_int, + onsets_s, + offsets_s, + time_bins, + unlabeled_label=self.unlabeled_label, + ) class ToLabels: @@ -87,6 +95,7 @@ class ToLabels: That maps string labels to integers. The mapping is inverted to convert back to string labels. """ + def __init__(self, labelmap: dict): self.labelmap = labelmap @@ -133,16 +142,13 @@ class ToSegments: calculated from the vector of times t. Default is 5. """ - def __init__(self, - labelmap: dict, - n_decimals_trunc: int = 5 - ): + def __init__(self, labelmap: dict, n_decimals_trunc: int = 5): self.labelmap = labelmap self.n_decimals_trunc = n_decimals_trunc - def __call__(self, - frame_labels: np.ndarray, - frame_times: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def __call__( + self, frame_labels: np.ndarray, frame_times: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Convert a vector of frame labels into segments in the form of onset indices, offset indices, and labels. @@ -180,7 +186,9 @@ def __call__(self, Vector where each element is the offset in seconds of a segment. Each offset corresponds to the value at the same index in labels. """ - return F.to_segments(frame_labels, self.labelmap, frame_times, self.n_decimals_trunc) + return F.to_segments( + frame_labels, self.labelmap, frame_times, self.n_decimals_trunc + ) class PostProcess: @@ -237,19 +245,20 @@ class PostProcess: because unlabeled segments makes it possible to identify the labeled segments. Default is False. 
""" - def __init__(self, - timebin_dur: float, - unlabeled_label: int = 0, - min_segment_dur: float | None = None, - majority_vote: bool = False, - ): + + def __init__( + self, + timebin_dur: float, + unlabeled_label: int = 0, + min_segment_dur: float | None = None, + majority_vote: bool = False, + ): self.timebin_dur = timebin_dur self.unlabeled_label = unlabeled_label self.min_segment_dur = min_segment_dur self.majority_vote = majority_vote - def __call__(self, - frame_labels: np.ndarray) -> np.ndarray: + def __call__(self, frame_labels: np.ndarray) -> np.ndarray: """Convert vector of frame labels into labels. Parameters @@ -265,5 +274,10 @@ def __call__(self, frame_labels : numpy.ndarray Vector of frame labels after post-processing is applied. """ - return F.postprocess(frame_labels, self.timebin_dur, self.unlabeled_label, - self.min_segment_dur, self.majority_vote) + return F.postprocess( + frame_labels, + self.timebin_dur, + self.unlabeled_label, + self.min_segment_dur, + self.majority_vote, + ) diff --git a/src/vak/transforms/functional.py b/src/vak/transforms/functional.py index 8fc515f22..e2ecbfd01 100644 --- a/src/vak/transforms/functional.py +++ b/src/vak/transforms/functional.py @@ -32,7 +32,9 @@ def standardize_spect(spect, mean_freqs, std_freqs, non_zero_std): """ tfm = spect - mean_freqs[:, np.newaxis] # need axis for broadcasting # keep any stds that are zero from causing NaNs - tfm[non_zero_std, :] = tfm[non_zero_std, :] / std_freqs[non_zero_std, np.newaxis] + tfm[non_zero_std, :] = ( + tfm[non_zero_std, :] / std_freqs[non_zero_std, np.newaxis] + ) return tfm @@ -124,8 +126,8 @@ def view_as_window_batch(arr, window_width): adapted from skimage.util.view_as_blocks https://github.com/scikit-image/scikit-image/blob/f1b7cf60fb80822849129cb76269b75b8ef18db1/skimage/util/shape.py#L9 """ - if not (type(window_width) == int and window_width > 0): - raise ValueError(f"window width must be a positive integer") + if not isinstance(window_width, int) or 
window_width < 1: + raise ValueError(f"`window_width` must be a positive integer, but was: {window_width}") if arr.ndim == 1: window_shape = (window_width,) diff --git a/src/vak/transforms/transforms.py b/src/vak/transforms/transforms.py index 9a14acb66..08920fa89 100644 --- a/src/vak/transforms/transforms.py +++ b/src/vak/transforms/transforms.py @@ -6,7 +6,6 @@ from ..common.validators import column_or_1d from . import functional as F - __all__ = [ "AddChannel", "PadToWindow", @@ -46,20 +45,28 @@ def __init__(self, mean_freqs=None, std_freqs=None, non_zero_std=None): non_zero_std : numpy.ndarray boolean, indicates where std_freqs has non-zero values. Used to avoid divide-by-zero errors. """ - if any([arg is not None for arg in (mean_freqs, std_freqs, non_zero_std)]): + if any( + [arg is not None for arg in (mean_freqs, std_freqs, non_zero_std)] + ): mean_freqs, std_freqs, non_zero_std = ( - column_or_1d(arr) for arr in (mean_freqs, std_freqs, non_zero_std) + column_or_1d(arr) + for arr in (mean_freqs, std_freqs, non_zero_std) ) if ( len( np.unique( - [arg.shape[0] for arg in (mean_freqs, std_freqs, non_zero_std)] + [ + arg.shape[0] + for arg in (mean_freqs, std_freqs, non_zero_std) + ] ) ) != 1 ): raise ValueError( - f"mean_freqs, std_freqs, and non_zero_std must all have the same length" + "`mean_freqs`, `std_freqs`, and `non_zero_std` must all have the same length.\n" + f"`mean_freqs.shape`: {mean_freqs.shape}, `std_freqs.shape`: {std_freqs.shape}, " + f"`non_zero_std.shape`: {non_zero_std.shape}" ) self.mean_freqs = mean_freqs @@ -67,7 +74,7 @@ def __init__(self, mean_freqs=None, std_freqs=None, non_zero_std=None): self.non_zero_std = non_zero_std @classmethod - def fit_dataset_path(cls, dataset_path, split='train'): + def fit_dataset_path(cls, dataset_path, split="train"): """Returns a :class:`StandardizeSpect` instance that is fit to a split from a dataset, given the path to that dataset and the @@ -85,16 +92,18 @@ def fit_dataset_path(cls, dataset_path, 
split='train'): standardize_spect : StandardizeSpect Instance that has been fit to input data from split. """ - from vak.datasets.frame_classification import Metadata from vak.datasets import frame_classification + from vak.datasets.frame_classification import Metadata dataset_path = pathlib.Path(dataset_path) metadata = Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_path = dataset_csv_path.parent df = pd.read_csv(dataset_csv_path) - df = df[df['split'] == split].copy() - frames_paths = df[frame_classification.constants.FRAMES_NPY_PATH_COL_NAME].values + df = df[df["split"] == split].copy() + frames_paths = df[ + frame_classification.constants.FRAMES_NPY_PATH_COL_NAME + ].values frames = np.load(dataset_path / frames_paths[0]) # in files, spectrograms are in orientation (freq bins, time bins) @@ -149,7 +158,9 @@ def __call__(self, spect): array standardized to same scale as set of spectrograms that SpectScaler was fit with """ - if any([not hasattr(self, attr) for attr in ["mean_freqs", "std_freqs"]]): + if any( + [not hasattr(self, attr) for attr in ["mean_freqs", "std_freqs"]] + ): raise AttributeError( "SpectScaler properties are set to None," "must call fit method first to set the" @@ -157,7 +168,7 @@ def __call__(self, spect): "transform" ) - if type(spect) != np.ndarray: + if not isinstance(spect, np.ndarray): raise TypeError( f"type of spect must be numpy.ndarray but was: {type(spect)}" ) @@ -214,8 +225,8 @@ class PadToWindow: """ def __init__(self, window_size, padval=0.0, return_padding_mask=True): - if not (type(window_size) == int) or ( - type(window_size) == float and window_size.is_integer() is False + if not isinstance(window_size, int) or ( + isinstance(window_size, float) and window_size.is_integer() is False ): raise ValueError( f"window size must be an int or a whole number float;" @@ -226,7 +237,7 @@ def __init__(self, window_size, padval=0.0, return_padding_mask=True): raise 
TypeError( f"type for padval must be int or float but was: {type(padval)}" ) - if not type(return_padding_mask) == bool: + if not isinstance(return_padding_mask, bool): raise TypeError( "return_padding_mask must be boolean (True or False), " f"but was type {type(return_padding_mask)} with value {return_padding_mask}" @@ -275,9 +286,9 @@ class ViewAsWindowBatch: https://github.com/scikit-image/scikit-image/blob/f1b7cf60fb80822849129cb76269b75b8ef18db1/skimage/util/shape.py#L9 """ - def __init__(self, window_width): - if not (type(window_width) == int) or ( - type(window_width) == float and window_width.is_integer() is False + def __init__(self, window_width: int | float): + if not isinstance(window_width, int) or ( + isinstance(window_width, float) and window_width.is_integer() is False ): raise ValueError( f"window size must be an int or a whole number float;" @@ -355,9 +366,9 @@ class AddChannel: Default is 0, which returns a tensor with dimensions (channel, height, width). """ - def __init__(self, channel_dim=0): - if not (type(channel_dim) == int) or ( - type(channel_dim) == float and channel_dim.is_integer() is False + def __init__(self, channel_dim: int | float = 0): + if not isinstance(channel_dim, int) or ( + isinstance(channel_dim, float) and channel_dim.is_integer() is False ): raise ValueError( f"window size must be an int or a whole number float;"