diff --git a/noxfile.py b/noxfile.py
index d83b8110c..afdca0793 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -94,7 +94,7 @@ def coverage(session) -> None:
     """
     session.install(".[test]")
     session.run(
-        "pytest", "--cov=./", "--cov-report=xml", *session.posargs
+        "pytest", "--slow-last", "--cov=./", "--cov-report=xml", *session.posargs
     )
diff --git a/src/vak/__about__.py b/src/vak/__about__.py
index 43a303008..15dd5d377 100644
--- a/src/vak/__about__.py
+++ b/src/vak/__about__.py
@@ -20,9 +20,7 @@
 
 __title__ = "vak"
-__summary__ = (
-    "A neural network framework for researchers studying acoustic communication"
-)
+__summary__ = "A neural network framework for researchers studying acoustic communication"
 __uri__ = "https://github.com/NickleDave/vak"
 
 __version__ = "1.0.0.post2"
diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py
index c0079e360..b187d805b 100644
--- a/src/vak/cli/predict.py
+++ b/src/vak/cli/predict.py
@@ -49,7 +49,11 @@ def predict(toml_path):
         checkpoint_path=cfg.predict.checkpoint_path,
         labelmap_path=cfg.predict.labelmap_path,
         num_workers=cfg.predict.num_workers,
-        timebins_key=cfg.prep.spect_params.timebins_key if cfg.prep else common.constants.TIMEBINS_KEY,
+        timebins_key=(
+            cfg.prep.spect_params.timebins_key
+            if cfg.prep
+            else common.constants.TIMEBINS_KEY
+        ),
         frames_standardizer_path=cfg.predict.frames_standardizer_path,
         annot_csv_filename=cfg.predict.annot_csv_filename,
         output_dir=cfg.predict.output_dir,
diff --git a/src/vak/common/constants.py b/src/vak/common/constants.py
index 044395d2b..a3183e3eb 100644
--- a/src/vak/common/constants.py
+++ b/src/vak/common/constants.py
@@ -61,4 +61,4 @@
 
 VALID_SPLITS = ("predict", "test", "train", "val")
 
-DEFAULT_BACKGROUND_LABEL = "background"
\ No newline at end of file
+DEFAULT_BACKGROUND_LABEL = "background"
diff --git a/src/vak/common/labels.py b/src/vak/common/labels.py
index e9c6d8296..dc30ed05c 100644
--- a/src/vak/common/labels.py
+++ b/src/vak/common/labels.py
@@ -9,7 +9,9 @@
 
 
 def to_map(
-    labelset: set, map_background: bool = True, background_label: str = constants.DEFAULT_BACKGROUND_LABEL
+    labelset: set,
+    map_background: bool = True,
+    background_label: str = constants.DEFAULT_BACKGROUND_LABEL,
 ) -> dict:
     """Convert set of labels to `dict`
     mapping those labels to a series of consecutive integers
diff --git a/src/vak/datasets/__init__.py b/src/vak/datasets/__init__.py
index cb7734cde..aa2f3be75 100644
--- a/src/vak/datasets/__init__.py
+++ b/src/vak/datasets/__init__.py
@@ -10,6 +10,4 @@
 ]
 
 # TODO: make this a proper registry
-DATASETS = {
-    "BioSoundSegBench": BioSoundSegBench
-}
+DATASETS = {"BioSoundSegBench": BioSoundSegBench}
diff --git a/src/vak/datasets/biosoundsegbench.py b/src/vak/datasets/biosoundsegbench.py
index a610c6530..f4b5369a8 100644
--- a/src/vak/datasets/biosoundsegbench.py
+++ b/src/vak/datasets/biosoundsegbench.py
@@ -1,16 +1,16 @@
 """Class representing BioSoundSegBench dataset."""
+
 from __future__ import annotations
 
 import json
 import pathlib
-from typing import Callable, Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Literal
 
-from attrs import define
 import numpy as np
 import pandas as pd
-
 import torch
 import torchvision.transforms
+from attrs import define
 
 from .. import common, datapipes, transforms
@@ -113,6 +113,7 @@ class TrainingReplicateMetadata:
     pre-defined training replicate
     in the BioSoundSegBench dataset.
""" + biosound_group: str id: str | None frame_dur: float @@ -123,51 +124,57 @@ class TrainingReplicateMetadata: def metadata_from_splits_json_path( - splits_json_path: pathlib.Path, datset_path: pathlib.Path - ) -> TrainingReplicateMetadata: + splits_json_path: pathlib.Path, datset_path: pathlib.Path +) -> TrainingReplicateMetadata: + name = splits_json_path.name try: - # Human-Speech doesn't have ID or data source in filename - # so it will raise a ValueError - name = splits_json_path.name - (biosound_group, - id_, - timebin_dur_1st_half, - timebin_dur_2nd_half, - unit, - data_source, - train_dur_1st_half, - train_dur_2nd_half, - replicate_num, - _, _ - ) = name.split('.') + ( + biosound_group, + unit, + id_, + frame_dur_1st_half, + frame_dur_2nd_half, + data_source, + train_dur_1st_half, + train_dur_2nd_half, + replicate_num, + _, + _, + ) = name.split(".") + # Human-Speech doesn't have ID or data source in filename + # so it will raise a ValueError except ValueError: name = splits_json_path.name - (biosound_group, - timebin_dur_1st_half, - timebin_dur_2nd_half, - unit, - train_dur_1st_half, - train_dur_2nd_half, - replicate_num, - _, _ - ) = name.split('.') + ( + biosound_group, + unit, + frame_dur_1st_half, + frame_dur_2nd_half, + train_dur_1st_half, + train_dur_2nd_half, + replicate_num, + _, + _, + ) = name.split(".") id_ = None data_source = None if id_ is not None: - id_ = id_.split('-')[-1] - timebin_dur = float( - timebin_dur_1st_half.split('-')[-1] + '.' + timebin_dur_2nd_half.split('-')[0] + id_ = id_.split("-")[-1] + frame_dur = float( + frame_dur_1st_half.split("-")[-1] + + "." + + frame_dur_2nd_half.split("-")[0] ) train_dur = float( - train_dur_1st_half.split('-')[-1] + '.' + train_dur_2nd_half.split('-')[0] - ) - replicate_num = int( - replicate_num.split('-')[-1] + train_dur_1st_half.split("-")[-1] + + "." 
+        + train_dur_2nd_half.split("-")[0]
     )
+    replicate_num = int(replicate_num.split("-")[-1])
     return TrainingReplicateMetadata(
         biosound_group,
         id_,
-        timebin_dur,
+        frame_dur,
         unit,
         data_source,
         train_dur,
@@ -184,10 +191,9 @@ def __init__(
         frames_standardizer: FramesStandardizer | None = None,
     ):
         from ..transforms import FramesStandardizer  # avoid circular import
+
         if frames_standardizer is not None:
-            if isinstance(
-                frames_standardizer, FramesStandardizer
-            ):
+            if isinstance(frames_standardizer, FramesStandardizer):
                 frames_transform = [frames_standardizer]
             else:
                 raise TypeError(
@@ -211,24 +217,30 @@ def __init__(
         self.frame_labels_transform = transforms.ToLongTensor()
 
     def __call__(
-            self,
-            frames: torch.Tensor,
-            multi_frame_labels: torch.Tensor | None = None,
-            binary_frame_labels: torch.Tensor | None = None,
-            boundary_frame_labels: torch.Tensor | None = None,
-    ) -> dict:
+        self,
+        frames: torch.Tensor,
+        multi_frame_labels: torch.Tensor | None = None,
+        binary_frame_labels: torch.Tensor | None = None,
+        boundary_frame_labels: torch.Tensor | None = None,
+    ) -> dict:
         frames = self.frames_transform(frames)
         item = {
             "frames": frames,
         }
 
         if multi_frame_labels is not None:
-            item["multi_frame_labels"] = self.frame_labels_transform(multi_frame_labels)
+            item["multi_frame_labels"] = self.frame_labels_transform(
+                multi_frame_labels
+            )
 
         if binary_frame_labels is not None:
-            item["binary_frame_labels"] = self.frame_labels_transform(binary_frame_labels)
+            item["binary_frame_labels"] = self.frame_labels_transform(
+                binary_frame_labels
+            )
 
         if boundary_frame_labels is not None:
-            item["boundary_frame_labels"] = self.frame_labels_transform(boundary_frame_labels)
+            item["boundary_frame_labels"] = self.frame_labels_transform(
+                boundary_frame_labels
+            )
 
         return item
@@ -285,9 +297,7 @@ def __init__(
         self.channel_dim = channel_dim
 
         if frames_standardizer is not None:
-            if not isinstance(
-                frames_standardizer, FramesStandardizer
-            ):
+            if not isinstance(frames_standardizer, FramesStandardizer):
                 raise TypeError(
                     f"Invalid type for frames_standardizer: {type(frames_standardizer)}. "
" "Should be an instance of vak.transforms.FramesStandardizer" @@ -335,13 +345,19 @@ def __call__( } if multi_frame_labels is not None: - item["multi_frame_labels"] = self.frame_labels_transform(multi_frame_labels) + item["multi_frame_labels"] = self.frame_labels_transform( + multi_frame_labels + ) if binary_frame_labels is not None: - item["binary_frame_labels"] = self.frame_labels_transform(binary_frame_labels) + item["binary_frame_labels"] = self.frame_labels_transform( + binary_frame_labels + ) if boundary_frame_labels is not None: - item["boundary_frame_labels"] = self.frame_labels_transform(boundary_frame_labels) + item["boundary_frame_labels"] = self.frame_labels_transform( + boundary_frame_labels + ) if padding_mask is not None: item["padding_mask"] = padding_mask @@ -355,6 +371,7 @@ def __call__( class BioSoundSegBench: """Class representing BioSoundSegBench dataset.""" + def __init__( self, dataset_path: str | pathlib.Path, @@ -369,7 +386,7 @@ def __init__( frame_labels_padval: int = -1, return_padding_mask: bool = False, return_frames_path: bool = False, - item_transform: Callable | None = None + item_transform: Callable | None = None, ): """BioSoundSegBench dataset.""" # ---- validate args, roughly in order @@ -387,7 +404,9 @@ def __init__( splits_path = pathlib.Path(splits_path) if not splits_path.exists(): - tmp_splits_path = dataset_path / "splits" / "splits-jsons" / splits_path + tmp_splits_path = ( + dataset_path / "splits" / "splits-jsons" / splits_path + ) if not tmp_splits_path.exists(): raise FileNotFoundError( f"Did not find `splits_path` using either absolute path ({splits_path})" @@ -413,10 +432,9 @@ def __init__( f"Valid `target_type` arguments are: {VALID_TARGET_TYPES}" ) if isinstance(target_type, (list, tuple)): - if not all([ - isinstance(target_type_, str) - for target_type_ in target_type - ]): + if not all( + [isinstance(target_type_, str) for target_type_ in target_type] + ): types_in_target_types = set( [type(target_type_) for target_type_ in target_type] ) @@ -443,13 +461,15 @@ def __init__( self.training_replicate_metadata = metadata_from_splits_json_path( self.splits_path, self.dataset_path ) - self.frame_dur = self.training_replicate_metadata.frame_dur * 1e-3 # convert from ms to s! + self.frame_dur = ( + self.training_replicate_metadata.frame_dur * 1e-3 + ) # convert from ms to s! if "multi_frame_labels" in target_type: labelmaps_json_path = self.dataset_path / "labelmaps.json" if not labelmaps_json_path.exists(): raise FileNotFoundError( - "`target_type` includes \"multi_frame_labels\" but " + '`target_type` includes "multi_frame_labels" but ' "'labelmaps.json' was not found in root of dataset path:\n" f"{labelmaps_json_path}" ) @@ -472,10 +492,10 @@ def __init__( f"group '{group}', unit '{unit}', and id '{id}'. " "Please check that splits_json path is correct." 
                    )
-        elif target_type == ('binary_frame_labels',):
-            self.labelmap = {'no segment': 0, 'segment': 1}
-        elif target_type == ('boundary_frame_labels',):
-            self.labelmap = {'no boundary': 0, 'boundary': 1}
+        elif target_type == ("binary_frame_labels",):
+            self.labelmap = {"no segment": 0, "segment": 1}
+        elif target_type == ("boundary_frame_labels",):
+            self.labelmap = {"no boundary": 0, "boundary": 1}
 
         self.split = split
         split_df = pd.read_csv(self.splits_metadata.splits_csv_path)
@@ -508,15 +528,20 @@ def __init__(
         self.inds_in_sample = np.load(
             getattr(self.splits_metadata.inds_in_sample_vector_paths, split)
         )
-        self.window_inds = datapipes.frame_classification.train_datapipe.get_window_inds(
-            self.sample_ids.shape[-1], window_size, stride
+        self.window_inds = (
+            datapipes.frame_classification.train_datapipe.get_window_inds(
+                self.sample_ids.shape[-1], window_size, stride
+            )
         )
 
         if item_transform is None:
             if standardize_frames and frames_standardizer is None:
                 from ..transforms import FramesStandardizer
-                frames_standardizer = FramesStandardizer.fit_inputs_targets_csv_path(
-                    self.splits_metadata.splits_csv_path, self.dataset_path
+
+                frames_standardizer = (
+                    FramesStandardizer.fit_inputs_targets_csv_path(
+                        self.splits_metadata.splits_csv_path, self.dataset_path
+                    )
                 )
             if split == "train":
                 self.item_transform = TrainItemTransform(
@@ -580,7 +605,8 @@ def _getitem_train(self, idx):
             item["frames"] = spect_dict[common.constants.SPECT_KEY]
             for target_type in self.target_type:
                 item[target_type] = np.load(
-                    self.dataset_path / self.target_paths[target_type][sample_id]
+                    self.dataset_path
+                    / self.target_paths[target_type][sample_id]
                 )
 
         elif len(uniq_sample_ids) > 1:
@@ -592,9 +618,7 @@ def _getitem_train(self, idx):
             for sample_id in sorted(uniq_sample_ids):
                 frames_path = self.dataset_path / self.frames_paths[sample_id]
                 spect_dict = common.files.spect.load(frames_path)
-                item["frames"].append(
-                    spect_dict[common.constants.SPECT_KEY]
-                )
+                item["frames"].append(spect_dict[common.constants.SPECT_KEY])
                 for target_type in self.target_type:
                     item[target_type].append(
                         np.load(
diff --git a/src/vak/datasets/get.py b/src/vak/datasets/get.py
index 3b38182a4..011a3df7b 100644
--- a/src/vak/datasets/get.py
+++ b/src/vak/datasets/get.py
@@ -1,7 +1,8 @@
 """Helper function that gets instances of classes representing datasets built into :mod:`vak`."""
+
 from __future__ import annotations
 
-from typing import Literal, Mapping, TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, Mapping
 
 from .. import common
 
@@ -10,11 +11,12 @@
 if TYPE_CHECKING:
     from ..transforms import FramesStandardizer
 
+
 def get(
-        dataset_config: dict,
-        split: Literal["predict", "test", "train", "val"],
-        frames_standardizer: FramesStandardizer | None = None,
-    ) -> Dataset:
+    dataset_config: dict,
+    split: Literal["predict", "test", "train", "val"],
+    frames_standardizer: FramesStandardizer | None = None,
+) -> Dataset:
     """Get an instance of a dataset class from :mod:`vak.datasets`.
 
     Parameters
     ----------
@@ -38,7 +40,7 @@ def get(
         raise KeyError(
             "A name is required to get a dataset, but "
             "`vak.datasets.get` received a `dataset_config` "
-            f"without a \"name\":\n{dataset_config}"
+            f'without a "name":\n{dataset_config}'
         )
     if split not in common.constants.VALID_SPLITS:
         raise ValueError(
             f"Valid splits are: {common.constants.VALID_SPLITS}"
         )
 
-    from .. import datasets
     dataset_name = dataset_config["name"]
     try:
         dataset_class = DATASETS[dataset_name]
@@ -54,13 +55,13 @@ def get(
         raise ValueError(
             f"Invalid dataset name: {dataset_name}\n."
             f"Built-in dataset names are: {DATASETS.keys()}"
-        )
+        ) from e
     if frames_standardizer is not None:
         dataset_config["params"]["frames_standardizer"] = frames_standardizer
     dataset = dataset_class(
         dataset_path=dataset_config["path"],
         splits_path=dataset_config["splits_path"],
         split=split,
-        **dataset_config["params"]
+        **dataset_config["params"],
     )
     return dataset
diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py
index ea20b72e6..12688288f 100644
--- a/src/vak/eval/frame_classification.py
+++ b/src/vak/eval/frame_classification.py
@@ -113,7 +113,9 @@ def eval_frame_classification_model(
         )
         frames_standardizer = joblib.load(frames_standardizer_path)
     else:
-        logger.info("No `frames_standardizer_path` provided, not standardizing frames.")
+        logger.info(
+            "No `frames_standardizer_path` provided, not standardizing frames."
+        )
         frames_standardizer = None
 
     logger.info(f"loading labelmap from path: {labelmap_path}")
@@ -150,7 +152,7 @@ def eval_frame_classification_model(
            frames_standardizer=frames_standardizer,
            return_padding_mask=True,
         )
-    # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------# ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
+    # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
     else:
         dataset_config["params"]["return_padding_mask"] = True
         val_dataset = datasets.get(
diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py
index edd691685..2da299cf3 100644
--- a/src/vak/models/frame_classification_model.py
+++ b/src/vak/models/frame_classification_model.py
@@ -94,7 +94,7 @@ def __init__(
         optimizer: torch.optim.Optimizer | None = None,
         metrics: dict | None = None,
         post_tfm: Callable | None = None,
-        background_label = common.constants.DEFAULT_BACKGROUND_LABEL,
+        background_label=common.constants.DEFAULT_BACKGROUND_LABEL,
     ):
         """Initialize a new instance of a
         :class:`~vak.models.frame_classification_model.FrameClassificationModel`.
@@ -141,7 +141,9 @@ def __init__(
         # with single-character labels
         # so that we do not affect edit distance computation
         # see https://github.com/NickleDave/vak/issues/373
-        labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != background_label]
+        labelmap_keys = [
+            lbl for lbl in labelmap.keys() if lbl != background_label
+        ]
         if any(
             [len(label) > 1 for label in labelmap_keys]
         ):  # only re-map if necessary
@@ -211,29 +213,29 @@ def training_step(self, batch: tuple, batch_idx: int):
         # we repeat this code in validation step
         # because I'm assuming it's faster than a call to a staticmethod that factors it out
         if (  # multi-class frame classificaton
-            "multi_frame_labels" in batch and
-            "binary_frame_labels" not in batch and
-            "boundary_frame_labels" not in batch
-            ):
-            target_types = ("multi_frame_labels",)
+            "multi_frame_labels" in batch
+            and "binary_frame_labels" not in batch
+            and "boundary_frame_labels" not in batch
+        ):
+            target_types = ("multi_frame_labels",)
         elif (  # binary frame classification
-            "binary_frame_labels" in batch and
-            "multi_frame_labels" not in batch and
-            "boundary_frame_labels" not in batch
-            ):
-            target_types = ("binary_frame_labels",)
+            "binary_frame_labels" in batch
+            and "multi_frame_labels" not in batch
+            and "boundary_frame_labels" not in batch
+        ):
+            target_types = ("binary_frame_labels",)
         elif (  # boundary "detection" -- i.e. different kind of binary frame classification
-            "boundary_frame_labels" in batch and
-            "multi_frame_labels" not in batch and
-            "binary_frame_labels" not in batch
-            ):
-            target_types = ("boundary_frame_labels",)
+            "boundary_frame_labels" in batch
+            and "multi_frame_labels" not in batch
+            and "binary_frame_labels" not in batch
+        ):
+            target_types = ("boundary_frame_labels",)
         elif (  # multi-class frame classification *and* boundary detection
-            "multi_frame_labels" in batch and
-            "boundary_frame_labels" in batch and
-            "binary_frame_labels" not in batch
-            ):
-            target_types = ("multi_frame_labels", "boundary_frame_labels")
+            "multi_frame_labels" in batch
+            and "boundary_frame_labels" in batch
+            and "binary_frame_labels" not in batch
+        ):
+            target_types = ("multi_frame_labels", "boundary_frame_labels")
 
         if len(target_types) == 1:
             class_logits = self.network(frames)
@@ -242,14 +244,17 @@ def training_step(self, batch: tuple, batch_idx: int):
         else:
             multi_logits, boundary_logits = self.network(frames)
             loss = self.loss(
-                multi_logits, boundary_logits, batch["multi_frame_labels"], batch["boundary_frame_labels"]
-            )
+                multi_logits,
+                boundary_logits,
+                batch["multi_frame_labels"],
+                batch["boundary_frame_labels"],
+            )
         if isinstance(loss, torch.Tensor):
-            self.log("train_loss", loss, on_step=True)
+            self.log("train_loss", loss, on_step=True)
         elif isinstance(loss, dict):
-            # this provides a mechanism to values for all terms of a loss function with multiple terms
-            for loss_name, loss_val in loss.items():
-                self.log(f"train_{loss_name}", loss_val, on_step=True)
+            # this provides a mechanism to log values for all terms of a loss function with multiple terms
+            for loss_name, loss_val in loss.items():
+                self.log(f"train_{loss_name}", loss_val, on_step=True)
 
         return loss
@@ -282,29 +287,29 @@ def validation_step(self, batch: tuple, batch_idx: int):
         # we repeat this code in training step
         # because I'm assuming it's faster than a call to a staticmethod that factors it out
         if (  # multi-class frame classificaton
-            "multi_frame_labels" in batch and
-            "binary_frame_labels" not in batch and
-            "boundary_frame_labels" not in batch
-            ):
-            target_types = ("multi_frame_labels",)
("multi_frame_labels",) + "multi_frame_labels" in batch + and "binary_frame_labels" not in batch + and "boundary_frame_labels" not in batch + ): + target_types = ("multi_frame_labels",) elif ( # binary frame classification - "binary_frame_labels" in batch and - "multi_frame_labels" not in batch and - "boundary_frame_labels" not in batch - ): - target_types = ("binary_frame_labels",) + "binary_frame_labels" in batch + and "multi_frame_labels" not in batch + and "boundary_frame_labels" not in batch + ): + target_types = ("binary_frame_labels",) elif ( # boundary "detection" -- i.e. different kind of binary frame classification - "boundary_frame_labels" in batch and - "multi_frame_labels" not in batch and - "binary_frame_labels" not in batch - ): - target_types = ("boundary_frame_labels",) + "boundary_frame_labels" in batch + and "multi_frame_labels" not in batch + and "binary_frame_labels" not in batch + ): + target_types = ("boundary_frame_labels",) elif ( # multi-class frame classification *and* boundary detection - "multi_frame_labels" in batch and - "boundary_frame_labels" in batch and - "binary_frame_labels" not in batch - ): - target_types = ("multi_frame_labels", "boundary_frame_labels") + "multi_frame_labels" in batch + and "boundary_frame_labels" in batch + and "binary_frame_labels" not in batch + ): + target_types = ("multi_frame_labels", "boundary_frame_labels") if len(target_types) == 1: class_logits = self.network(frames) @@ -350,11 +355,13 @@ def validation_step(self, batch: tuple, batch_idx: int): class_preds = class_preds[:, padding_mask] if boundary_logits is not None: - boundary_logits = boundary_logits[:, :, padding_mask] - boundary_preds = boundary_preds[:, padding_mask] + boundary_logits = boundary_logits[:, :, padding_mask] + boundary_preds = boundary_preds[:, padding_mask] if "multi_frame_labels" in target_types: - multi_frame_labels_str = self.to_labels_eval(batch["multi_frame_labels"].cpu().numpy()) + multi_frame_labels_str = self.to_labels_eval( + batch["multi_frame_labels"].cpu().numpy() + ) class_preds_str = self.to_labels_eval(class_preds.cpu().numpy()) if self.post_tfm: @@ -363,12 +370,16 @@ def validation_step(self, batch: tuple, batch_idx: int): ) class_preds_tfm_str = self.to_labels_eval(class_preds_tfm) # convert back to tensor so we can compute accuracy - class_preds_tfm = torch.from_numpy(class_preds_tfm).to(self.device) + class_preds_tfm = torch.from_numpy(class_preds_tfm).to( + self.device + ) if len(target_types) == 1: target = batch[target_types[0]] else: - target = {target_type: batch[target_type] for target_type in target_types} + target = { + target_type: batch[target_type] for target_type in target_types + } for metric_name, metric_callable in self.metrics.items(): if metric_name == "loss": @@ -382,8 +393,11 @@ def validation_step(self, batch: tuple, batch_idx: int): ) else: loss = self.loss( - class_logits, boundary_logits, batch["multi_frame_labels"], batch["boundary_frame_labels"] - ) + class_logits, + boundary_logits, + batch["multi_frame_labels"], + batch["boundary_frame_labels"], + ) if isinstance(loss, torch.Tensor): self.log( f"val_{metric_name}", @@ -422,14 +436,18 @@ def validation_step(self, batch: tuple, batch_idx: int): else: self.log( f"val_{metric_name}", - metric_callable(class_preds, target["multi_frame_labels"]), + metric_callable( + class_preds, target["multi_frame_labels"] + ), batch_size=1, on_step=True, sync_dist=True, ) self.log( f"val_boundary_{metric_name}", - metric_callable(boundary_preds, 
target["boundary_frame_labels"]), + metric_callable( + boundary_preds, target["boundary_frame_labels"] + ), batch_size=1, on_step=True, sync_dist=True, @@ -437,7 +455,9 @@ def validation_step(self, batch: tuple, batch_idx: int): if self.post_tfm and "multi_frame_labels" in target_types: self.log( f"val_multi_{metric_name}_tfm", - metric_callable(class_preds_tfm, target["multi_frame_labels"]), + metric_callable( + class_preds_tfm, target["multi_frame_labels"] + ), batch_size=1, on_step=True, sync_dist=True, @@ -449,7 +469,11 @@ def validation_step(self, batch: tuple, batch_idx: int): self.log( f"val_{metric_name}", # next line: convert to float to squelch warning from lightning - float(metric_callable(class_preds_str, multi_frame_labels_str)), + float( + metric_callable( + class_preds_str, multi_frame_labels_str + ) + ), batch_size=1, on_step=True, sync_dist=True, @@ -458,7 +482,11 @@ def validation_step(self, batch: tuple, batch_idx: int): self.log( f"val_{metric_name}_tfm", # next line: convert to float to squelch warning from lightning - float(metric_callable(class_preds_tfm_str, multi_frame_labels_str)), + float( + metric_callable( + class_preds_tfm_str, multi_frame_labels_str + ) + ), batch_size=1, on_step=True, sync_dist=True, @@ -485,7 +513,10 @@ def predict_step(self, batch: tuple, batch_idx: int): containing the spectrogram for which a prediction was generated. """ - frames, frames_path = batch["frames"].to(self.device), batch["frames_path"] + frames, frames_path = ( + batch["frames"].to(self.device), + batch["frames_path"], + ) if isinstance(frames_path, list) and len(frames_path) == 1: frames_path = frames_path[0] # TODO: fix this weirdness. Diff't collate_fn? diff --git a/src/vak/nn/loss/__init__.py b/src/vak/nn/loss/__init__.py index 59e0194ce..92b262734 100644 --- a/src/vak/nn/loss/__init__.py +++ b/src/vak/nn/loss/__init__.py @@ -3,6 +3,7 @@ from .umap import UmapLoss, umap_loss __all__ = [ + "CrossEntropyLoss", "DiceLoss", "dice_loss", "UmapLoss", diff --git a/src/vak/nn/loss/crossentropy.py b/src/vak/nn/loss/crossentropy.py index eb0882629..666d61dc4 100644 --- a/src/vak/nn/loss/crossentropy.py +++ b/src/vak/nn/loss/crossentropy.py @@ -1,17 +1,33 @@ import torch + class CrossEntropyLoss(torch.nn.CrossEntropyLoss): """Wrapper around :class:`torch.nn.CrossEntropyLoss` Converts the argument ``weight`` to a :class:`torch.Tensor` if it is a :class:`list`. """ - def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, - reduction='mean', label_smoothing=0.0): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction="mean", + label_smoothing=0.0, + ): if weight is not None: if isinstance(weight, torch.Tensor): pass elif isinstance(weight, list): weight = torch.Tensor(weight) - super().__init__(weight, size_average, ignore_index, reduce, reduction, label_smoothing) + super().__init__( + weight, + size_average, + ignore_index, + reduce, + reduction, + label_smoothing, + ) diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index 57d06d852..eff70ef48 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -7,13 +7,13 @@ import os import pathlib -from attrs import define import crowsetta import joblib import lightning -import pandas as pd import numpy as np +import pandas as pd import torch.utils.data +from attrs import define from tqdm import tqdm from .. 
@@ -23,7 +23,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 @define
 class AnnotationDataFrame:
     """Data class that represents annotations
@@ -32,8 +31,9 @@ class AnnotationDataFrame:
     Used to save annotations that currently can't be saved
     with :mod:`crowsetta`, e.g. boundary times.
     """
+
     df: pd.DataFrame
-    audio_path : str | pathlib.Path
+    audio_path: str | pathlib.Path
 
 
 def predict_with_frame_classification_model(
@@ -43,7 +43,7 @@ def predict_with_frame_classification_model(
     checkpoint_path: str | pathlib.Path,
     labelmap_path: str | pathlib.Path,
     num_workers: int = 2,
-    timebins_key:str = "t",
+    timebins_key: str = "t",
     frames_standardizer_path: str | pathlib.Path | None = None,
     annot_csv_filename: str | None = None,
     output_dir: str | pathlib.Path | None = None,
@@ -201,6 +201,7 @@ def predict_with_frame_classification_model(
     # but fail early here if we don't have it
     if "target_type" not in dataset_config["params"]:
         from ..datasets.biosoundsegbench import VALID_TARGET_TYPES
+
         raise ValueError(
             "The dataset table in the configuration file requires a 'target_type' "
             "when running predictions on built-in datasets. "
@@ -318,7 +319,10 @@ def predict_with_frame_classification_model(
         if isinstance(frames_path, list) and len(frames_path) == 1:
             frames_path = frames_path[0]
         # we do all this basically to have clear naming below
-        if target_type == "multi_frame_labels" or target_type == "binary_frame_labels":
+        if (
+            target_type == "multi_frame_labels"
+            or target_type == "binary_frame_labels"
+        ):
             class_logits = pred_dict[frames_path]
             boundary_logits = None
         elif target_type == "boundary_frame_labels":
@@ -342,11 +346,19 @@ def predict_with_frame_classification_model(
             np.savez(net_output_path, net_output)
 
         if class_logits is not None:
-            class_preds = torch.argmax(class_logits, dim=1)  # assumes class dimension is 1
-            class_preds = torch.flatten(class_preds).cpu().numpy()[padding_mask]
+            class_preds = torch.argmax(
+                class_logits, dim=1
+            )  # assumes class dimension is 1
+            class_preds = (
+                torch.flatten(class_preds).cpu().numpy()[padding_mask]
+            )
         if boundary_logits is not None:
-            boundary_preds = torch.argmax(boundary_logits, dim=1)  # assumes class dimension is 1
-            boundary_preds = torch.flatten(boundary_preds).cpu().numpy()[padding_mask]
+            boundary_preds = torch.argmax(
+                boundary_logits, dim=1
+            )  # assumes class dimension is 1
+            boundary_preds = (
+                torch.flatten(boundary_preds).cpu().numpy()[padding_mask]
+            )
 
         if input_type == "audio":
             frames, samplefreq = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](
@@ -361,11 +373,16 @@ def predict_with_frame_classification_model(
             # audio_fname is used for audio_path attribute of crowsetta.Annotation below
             audio_fname = files.spect.find_audio_fname(frames_path)
 
-        if target_type == "multi_frame_labels" or target_type == "binary_frame_labels":
+        if (
+            target_type == "multi_frame_labels"
+            or target_type == "binary_frame_labels"
+        ):
             if majority_vote or min_segment_dur:
                 if background_label in labelmap:
                     background_label = labelmap[background_label]
-                elif "unlabeled" in labelmap:  # some backward compatibility here
+                elif (
+                    "unlabeled" in labelmap
+                ):  # some backward compatibility here
                     background_label = labelmap["unlabeled"]
                 else:
                     background_label = 0  # set a default value anyway just to not throw an error
@@ -390,20 +407,22 @@ def predict_with_frame_classification_model(
             )
 
             annot = crowsetta.Annotation(
-                seq=seq, notated_path=audio_fname, annot_path=annot_csv_path.name
+                seq=seq,
+                notated_path=audio_fname,
+                annot_path=annot_csv_path.name,
             )
             annots.append(annot)
 
         elif target_type == "boundary_frame_labels":
-            boundary_inds = transforms.frame_labels.boundary_inds_from_boundary_labels(
-                boundary_preds,
-                force_boundary_first_ind=True,
+            boundary_inds = (
+                transforms.frame_labels.boundary_inds_from_boundary_labels(
+                    boundary_preds,
+                    force_boundary_first_ind=True,
+                )
             )
             boundary_times = frame_times[boundary_inds]  # fancy indexing
-            df = pd.DataFrame.from_records({'boundary_time': boundary_times})
-            annots.append(
-                AnnotationDataFrame(df=df, audio_path=audio_fname)
-            )
+            df = pd.DataFrame.from_records({"boundary_time": boundary_times})
+            annots.append(AnnotationDataFrame(df=df, audio_path=audio_fname))
 
         elif target_type == ("boundary_frame_labels", "multi_frame_labels"):
             if majority_vote is False:
                 logger.warn(
@@ -417,7 +436,9 @@ def predict_with_frame_classification_model(
                 elif "unlabeled" in labelmap:  # some backward compatibility here
                     background_label = labelmap["unlabeled"]
                 else:
-                    background_label = 0  # set a default value anyway just to not throw an error
+                    background_label = (
+                        0  # set a default value anyway just to not throw an error
+                    )
             # Notice here we *always* call post-process, with majority_vote=True
             # because we are using boundary labels
             class_preds = transforms.frame_labels.postprocess(
@@ -433,10 +454,6 @@ def predict_with_frame_classification_model(
                 labelmap=labelmap,
                 frame_times=frame_times,
             )
-            if labels is None and onsets_s is None and offsets_s is None:
-                # handle the case when all time bins are predicted to be unlabeled
-                # see https://github.com/NickleDave/vak/issues/383
-                continue
             if labels is None and onsets_s is None and offsets_s is None:
                 # handle the case when all time bins are predicted to be unlabeled
                 # see https://github.com/NickleDave/vak/issues/383
@@ -446,7 +463,9 @@ def predict_with_frame_classification_model(
                 )
 
             annot = crowsetta.Annotation(
-                seq=seq, notated_path=audio_fname, annot_path=annot_csv_path.name
+                seq=seq,
+                notated_path=audio_fname,
+                annot_path=annot_csv_path.name,
             )
             annots.append(annot)
 
@@ -457,8 +476,8 @@ def predict_with_frame_classification_model(
         df_out = []
         for sample_num, annot_df in enumerate(annots):
             df = annot_df.df
-            df['audio_path'] = str(annot_df.audio_path)
-            df['sample_num'] = sample_num
+            df["audio_path"] = str(annot_df.audio_path)
+            df["sample_num"] = sample_num
             df_out.append(df)
         df_out = pd.concat(df_out)
         df_out.to_csv(annot_csv_path, index=False)
diff --git a/src/vak/prep/frame_classification/learncurve.py b/src/vak/prep/frame_classification/learncurve.py
index 884e37469..001d7243a 100644
--- a/src/vak/prep/frame_classification/learncurve.py
+++ b/src/vak/prep/frame_classification/learncurve.py
@@ -176,7 +176,7 @@ def make_subsets_from_dataset_df(
     num_replicates: int,
     dataset_path: pathlib.Path,
     labelmap: dict,
-    background_label : str = common.constants.DEFAULT_BACKGROUND_LABEL,
+    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL,
 ) -> pd.DataFrame:
     """Make subsets of the training data split
     for a learning curve.
diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py
index ed658c6ed..f83f6518f 100644
--- a/src/vak/train/frame_classification.py
+++ b/src/vak/train/frame_classification.py
@@ -200,19 +200,22 @@ def train_frame_classification_model(
                 "No `frames_standardizer_path` provided, not loading",
             )
             logger.info("Will standardize (normalize) frames")
-            frames_standardizer = transforms.FramesStandardizer.fit_dataset_path(
-                dataset_path,
-                split="train",
-                subset=subset,
+            frames_standardizer = (
+                transforms.FramesStandardizer.fit_dataset_path(
+                    dataset_path,
+                    split="train",
+                    subset=subset,
+                )
             )
             joblib.dump(
-                frames_standardizer, results_path.joinpath("FramesStandardizer")
+                frames_standardizer,
+                results_path.joinpath("FramesStandardizer"),
             )
         elif frames_standardizer_path is not None and not standardize_frames:
             raise ValueError(
                 "`frames_standardizer_path` provided but `standardize_frames` was False, these options conflict"
             )
-    # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
+    # ---- *yes* using a built-in dataset --------------------------------------------------------------------------
    else:  # not standardize_frames and frames_standardizer_path is None:
        logger.info(
@@ -235,13 +238,13 @@ def train_frame_classification_model(
         # while still accepting a transform but defaulting to None)
         if "standardize_frames" not in dataset_config:
             logger.info(
-                f"Adding `standardize_frames` argument to dataset_config[\"params\"]: {standardize_frames}"
+                f'Adding `standardize_frames` argument to dataset_config["params"]: {standardize_frames}'
             )
             dataset_config["params"]["standardize_frames"] = standardize_frames
         train_dataset = datasets.get(
             dataset_config,
             split="train",
-            )
+        )
         logger.info(
             f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}",
         )
@@ -249,13 +252,16 @@ def train_frame_classification_model(
         labelmap = train_dataset.labelmap
         with open(results_path.joinpath("labelmap.json"), "w") as fp:
             json.dump(labelmap, fp)
-        frames_standardizer = getattr(train_dataset.item_transform, 'frames_standardizer')
+        frames_standardizer = getattr(
+            train_dataset.item_transform, "frames_standardizer"
+        )
         if frames_standardizer is not None:
             logger.info(
-                f"Saving `frames_standardizer` from item transform on training dataset"
+                "Saving `frames_standardizer` from item transform on training dataset"
             )
             joblib.dump(
-                frames_standardizer, results_path.joinpath("FramesStandardizer")
+                frames_standardizer,
+                results_path.joinpath("FramesStandardizer"),
             )
 
     logger.info(
@@ -274,7 +280,9 @@ def train_frame_classification_model(
         f"Will measure error on validation set every {val_step} steps of training",
     )
     if dataset_config["name"] is None:
-        logger.info(f"Using validation split from dataset:\n{dataset_path}")
+        logger.info(
+            f"Using validation split from dataset:\n{dataset_path}"
+        )
         val_dur = get_split_dur(dataset_df, "val")
         logger.info(
             f"Total duration of validation split from dataset (in s): {val_dur}",
diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py
index 85f435eea..09e9eaf2e 100644
--- a/src/vak/transforms/frame_labels/functional.py
+++ b/src/vak/transforms/frame_labels/functional.py
@@ -83,7 +83,9 @@ def from_segments(
             "labels_int must be a list or numpy.ndarray of integers"
         )
 
-    label_vec = np.ones((time_bins.shape[-1],), dtype="int8") * background_label
+    label_vec = (
+        np.ones((time_bins.shape[-1],), dtype="int8") * background_label
+    )
     onset_inds = [np.argmin(np.abs(time_bins - onset)) for onset in onsets_s]
     offset_inds = [
         np.argmin(np.abs(time_bins - offset)) for offset in offsets_s
     ]
@@ -96,8 +98,9 @@ def from_segments(
 
 
 def to_labels(
-    frame_labels: npt.NDArray, labelmap: dict,
-    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL
+    frame_labels: npt.NDArray,
+    labelmap: dict,
+    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL,
 ) -> str:
     """Convert vector of frame labels to a string,
     one character for each continuous segment.
@@ -156,7 +159,7 @@ def to_segments(
     labelmap: dict,
     frame_times: npt.NDArray,
     n_decimals_trunc: int = 5,
-    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL
+    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL,
 ) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """Convert a vector of frame labels
     into segments in the form of onset indices,
@@ -382,8 +385,7 @@ def take_majority_vote(
 
 
 def boundary_inds_from_boundary_labels(
-    boundary_labels: npt.NDArray,
-    force_boundary_first_ind: bool = True
+    boundary_labels: npt.NDArray, force_boundary_first_ind: bool = True
 ) -> npt.NDArray:
     """Return a :class:`numpy.ndarray` with the indices of boundaries,
     given a 1-D vector of boundary labels.
@@ -409,8 +411,7 @@ def boundary_inds_from_boundary_labels(
 
 
 def segment_inds_list_from_boundary_labels(
-    boundary_labels: npt.NDArray,
-    force_boundary_first_ind: bool = True
+    boundary_labels: npt.NDArray, force_boundary_first_ind: bool = True
 ) -> list[npt.NDArray]:
     """Given an array of boundary labels,
     return a list of :class:`numpy.ndarray` vectors,
@@ -434,7 +435,9 @@ def segment_inds_list_from_boundary_labels(
         Of fancy indexing arrays.
         Each array can be used to index one segment in ``frame_labels``.
     """
-    boundary_inds = boundary_inds_from_boundary_labels(boundary_labels, force_boundary_first_ind)
+    boundary_inds = boundary_inds_from_boundary_labels(
+        boundary_labels, force_boundary_first_ind
+    )
 
     # at the end of `boundary_inds``, insert an imaginary "last" boundary we use just with ``np.arange`` below
     np.insert(boundary_inds, boundary_inds.shape[0], boundary_labels.shape[0])
@@ -452,7 +455,7 @@ def postprocess(
     background_label: int = 0,
     min_segment_dur: float | None = None,
     majority_vote: bool = False,
-    boundary_labels: npt.NDArray | None = None
+    boundary_labels: npt.NDArray | None = None,
 ) -> npt.NDArray:
     """Apply post-processing transformations
     to a vector of frame labels.
@@ -532,7 +535,10 @@ def postprocess(
     # handle the case when all time bins are predicted to be unlabeled
     # see https://github.com/NickleDave/vak/issues/383
     uniq_frame_labels = np.unique(frame_labels)
-    if len(uniq_frame_labels) == 1 and uniq_frame_labels[0] == background_label:
+    if (
+        len(uniq_frame_labels) == 1
+        and uniq_frame_labels[0] == background_label
+    ):
         return frame_labels  # -> no need to do any of the post-processing
 
     if boundary_labels is not None:
diff --git a/tests/test_datasets/conftest.py b/tests/test_datasets/conftest.py
index 16699079d..f1742e659 100644
--- a/tests/test_datasets/conftest.py
+++ b/tests/test_datasets/conftest.py
@@ -7,16 +7,16 @@
 
 
 SPLITS_JSON = {
-    "splits_csv_path": "splits/inputs-targets-paths-csvs/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.csv",
+    "splits_csv_path": "splits/inputs-targets-paths-csvs/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.csv",
     "sample_id_vec_path": {
-        "test": "splits/sample-id-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.test.sample_ids.npy",
-        "train": "splits/sample-id-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.train.sample_ids.npy",
-        "val": "splits/sample-id-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.val.sample_ids.npy"
+        "test": "splits/sample-id-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.test.sample_ids.npy",
+        "train": "splits/sample-id-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.train.sample_ids.npy",
+        "val": "splits/sample-id-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.val.sample_ids.npy"
     },
     "inds_in_sample_vec_path": {
-        "test": "splits/inds-in-sample-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.test.inds_in_sample.npy",
-        "train": "splits/inds-in-sample-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.train.inds_in_sample.npy",
-        "val": "splits/inds-in-sample-vectors/Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.val.inds_in_sample.npy"
+        "test": "splits/inds-in-sample-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.test.inds_in_sample.npy",
+        "train": "splits/inds-in-sample-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.train.inds_in_sample.npy",
+        "val": "splits/inds-in-sample-vectors/Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.val.inds_in_sample.npy"
     }
 }
@@ -349,7 +349,7 @@ def mock_biosoundsegbench_dataset(tmp_path):
     inputs_targets_csv_dir = splits_dir / "inputs-targets-paths-csvs"
     inputs_targets_csv_dir.mkdir()
     df = pd.DataFrame.from_records(INPUTS_TARGETS_CSV_RECORDS)
-    splits_csv = df.to_csv(inputs_targets_csv_dir / "Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.csv")
+    splits_csv = df.to_csv(inputs_targets_csv_dir / "Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.csv")
     df.to_csv(splits_csv)
 
     sample_id_vecs_dir = splits_dir / "sample-id-vectors"
@@ -359,11 +359,11 @@ def mock_biosoundsegbench_dataset(tmp_path):
     for split in "train", "val", "test":
split in "train", "val", "test": sample_id_vec = np.zeros(10) - np.save(sample_id_vecs_dir / f"Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.{split}.sample_ids.npy", sample_id_vec) + np.save(sample_id_vecs_dir / f"Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.{split}.sample_ids.npy", sample_id_vec) inds_in_sample_vec = np.arange(10) - np.save(inds_in_sample_vecs_dir / f"Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.{split}.inds_in_sample.npy", inds_in_sample_vec) + np.save(inds_in_sample_vecs_dir / f"Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.{split}.inds_in_sample.npy", inds_in_sample_vec) - splits_path = dataset_path / "Mouse-Pup-Call.id-SW.timebin-1.5-ms.call.id-data-only.train-dur-1500.0.replicate-1.splits.json" + splits_path = dataset_path / "Mouse-Pup-Call.call.id-SW.frame-dur-1.5-ms.id-data-only.train-dur-1500.0.replicate-1.splits.json" with splits_path.open('w') as fp: json.dump(SPLITS_JSON, fp)