Skip to content

Commit

Permalink
BUG: minor fixes to datasets.BioSoundSegBench class (#766)
Browse files Browse the repository at this point in the history
* Fix how we get metadata from splits json path in datasets.biosoundsegbench

* CLN: Remove duplicated code in predict/frame_classification.py

* Apply formatting from linters

* Make flake8 fixes

* Fix SPLITS_JSON in test_datasets/conftest.py to match the changed naming scheme for the BioSoundSegBench dataset

* Set '--slow-last' option when we call pytest in nox session 'coverage' so that tests fail faster on CI
  • Loading branch information
NickleDave authored Jul 6, 2024
1 parent d6331d5 commit 3763838
Show file tree
Hide file tree
Showing 17 changed files with 332 additions and 222 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def coverage(session) -> None:
"""
session.install(".[test]")
session.run(
"pytest", "--cov=./", "--cov-report=xml", *session.posargs
"pytest", "--slow-last", "--cov=./", "--cov-report=xml", *session.posargs
)


Expand Down
4 changes: 1 addition & 3 deletions src/vak/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@


__title__ = "vak"
__summary__ = (
"A neural network framework for researchers studying acoustic communication"
)
__summary__ = "A neural network framework for researchers studying acoustic communication"
__uri__ = "https://github.com/NickleDave/vak"

__version__ = "1.0.0.post2"
Expand Down
6 changes: 5 additions & 1 deletion src/vak/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def predict(toml_path):
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key if cfg.prep else common.constants.TIMEBINS_KEY,
timebins_key=(
cfg.prep.spect_params.timebins_key
if cfg.prep
else common.constants.TIMEBINS_KEY
),
frames_standardizer_path=cfg.predict.frames_standardizer_path,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
Expand Down
2 changes: 1 addition & 1 deletion src/vak/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@

VALID_SPLITS = ("predict", "test", "train", "val")

DEFAULT_BACKGROUND_LABEL = "background"
DEFAULT_BACKGROUND_LABEL = "background"
4 changes: 3 additions & 1 deletion src/vak/common/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


def to_map(
labelset: set, map_background: bool = True, background_label: str = constants.DEFAULT_BACKGROUND_LABEL
labelset: set,
map_background: bool = True,
background_label: str = constants.DEFAULT_BACKGROUND_LABEL,
) -> dict:
"""Convert set of labels to `dict`
mapping those labels to a series of consecutive integers
Expand Down
4 changes: 1 addition & 3 deletions src/vak/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,4 @@
]

# TODO: make this a proper registry
DATASETS = {
"BioSoundSegBench": BioSoundSegBench
}
DATASETS = {"BioSoundSegBench": BioSoundSegBench}
172 changes: 98 additions & 74 deletions src/vak/datasets/biosoundsegbench.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Class representing BioSoundSegBench dataset."""

from __future__ import annotations

import json
import pathlib
from typing import Callable, Literal, TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Literal

from attrs import define
import numpy as np
import pandas as pd

import torch
import torchvision.transforms
from attrs import define

from .. import common, datapipes, transforms

Expand Down Expand Up @@ -113,6 +113,7 @@ class TrainingReplicateMetadata:
pre-defined training replicate
in the BioSoundSegBench dataset.
"""

biosound_group: str
id: str | None
frame_dur: float
Expand All @@ -123,51 +124,57 @@ class TrainingReplicateMetadata:


def metadata_from_splits_json_path(
splits_json_path: pathlib.Path, datset_path: pathlib.Path
) -> TrainingReplicateMetadata:
splits_json_path: pathlib.Path, datset_path: pathlib.Path
) -> TrainingReplicateMetadata:
name = splits_json_path.name
try:
# Human-Speech doesn't have ID or data source in filename
# so it will raise a ValueError
name = splits_json_path.name
(biosound_group,
id_,
timebin_dur_1st_half,
timebin_dur_2nd_half,
unit,
data_source,
train_dur_1st_half,
train_dur_2nd_half,
replicate_num,
_, _
) = name.split('.')
(
biosound_group,
unit,
id_,
frame_dur_1st_half,
frame_dur_2nd_half,
data_source,
train_dur_1st_half,
train_dur_2nd_half,
replicate_num,
_,
_,
) = name.split(".")
# Human-Speech doesn't have ID or data source in filename
# so it will raise a ValueError
except ValueError:
name = splits_json_path.name
(biosound_group,
timebin_dur_1st_half,
timebin_dur_2nd_half,
unit,
train_dur_1st_half,
train_dur_2nd_half,
replicate_num,
_, _
) = name.split('.')
(
biosound_group,
unit,
frame_dur_1st_half,
frame_dur_2nd_half,
train_dur_1st_half,
train_dur_2nd_half,
replicate_num,
_,
_,
) = name.split(".")
id_ = None
data_source = None
if id_ is not None:
id_ = id_.split('-')[-1]
timebin_dur = float(
timebin_dur_1st_half.split('-')[-1] + '.' + timebin_dur_2nd_half.split('-')[0]
id_ = id_.split("-")[-1]
frame_dur = float(
frame_dur_1st_half.split("-")[-1]
+ "."
+ frame_dur_2nd_half.split("-")[0]
)
train_dur = float(
train_dur_1st_half.split('-')[-1] + '.' + train_dur_2nd_half.split('-')[0]
)
replicate_num = int(
replicate_num.split('-')[-1]
train_dur_1st_half.split("-")[-1]
+ "."
+ train_dur_2nd_half.split("-")[0]
)
replicate_num = int(replicate_num.split("-")[-1])
return TrainingReplicateMetadata(
biosound_group,
id_,
timebin_dur,
frame_dur,
unit,
data_source,
train_dur,
Expand All @@ -184,10 +191,9 @@ def __init__(
frames_standardizer: FramesStandardizer | None = None,
):
from ..transforms import FramesStandardizer # avoid circular import

if frames_standardizer is not None:
if isinstance(
frames_standardizer, FramesStandardizer
):
if isinstance(frames_standardizer, FramesStandardizer):
frames_transform = [frames_standardizer]
else:
raise TypeError(
Expand All @@ -211,24 +217,30 @@ def __init__(
self.frame_labels_transform = transforms.ToLongTensor()

def __call__(
self,
frames: torch.Tensor,
multi_frame_labels: torch.Tensor | None = None,
binary_frame_labels: torch.Tensor | None = None,
boundary_frame_labels: torch.Tensor | None = None,
) -> dict:
self,
frames: torch.Tensor,
multi_frame_labels: torch.Tensor | None = None,
binary_frame_labels: torch.Tensor | None = None,
boundary_frame_labels: torch.Tensor | None = None,
) -> dict:
frames = self.frames_transform(frames)
item = {
"frames": frames,
}
if multi_frame_labels is not None:
item["multi_frame_labels"] = self.frame_labels_transform(multi_frame_labels)
item["multi_frame_labels"] = self.frame_labels_transform(
multi_frame_labels
)

if binary_frame_labels is not None:
item["binary_frame_labels"] = self.frame_labels_transform(binary_frame_labels)
item["binary_frame_labels"] = self.frame_labels_transform(
binary_frame_labels
)

if boundary_frame_labels is not None:
item["boundary_frame_labels"] = self.frame_labels_transform(boundary_frame_labels)
item["boundary_frame_labels"] = self.frame_labels_transform(
boundary_frame_labels
)

return item

Expand Down Expand Up @@ -285,9 +297,7 @@ def __init__(
self.channel_dim = channel_dim

if frames_standardizer is not None:
if not isinstance(
frames_standardizer, FramesStandardizer
):
if not isinstance(frames_standardizer, FramesStandardizer):
raise TypeError(
f"Invalid type for frames_standardizer: {type(frames_standardizer)}. "
"Should be an instance of vak.transforms.FramesStandardizer"
Expand Down Expand Up @@ -335,13 +345,19 @@ def __call__(
}

if multi_frame_labels is not None:
item["multi_frame_labels"] = self.frame_labels_transform(multi_frame_labels)
item["multi_frame_labels"] = self.frame_labels_transform(
multi_frame_labels
)

if binary_frame_labels is not None:
item["binary_frame_labels"] = self.frame_labels_transform(binary_frame_labels)
item["binary_frame_labels"] = self.frame_labels_transform(
binary_frame_labels
)

if boundary_frame_labels is not None:
item["boundary_frame_labels"] = self.frame_labels_transform(boundary_frame_labels)
item["boundary_frame_labels"] = self.frame_labels_transform(
boundary_frame_labels
)

if padding_mask is not None:
item["padding_mask"] = padding_mask
Expand All @@ -355,6 +371,7 @@ def __call__(

class BioSoundSegBench:
"""Class representing BioSoundSegBench dataset."""

def __init__(
self,
dataset_path: str | pathlib.Path,
Expand All @@ -369,7 +386,7 @@ def __init__(
frame_labels_padval: int = -1,
return_padding_mask: bool = False,
return_frames_path: bool = False,
item_transform: Callable | None = None
item_transform: Callable | None = None,
):
"""BioSoundSegBench dataset."""
# ---- validate args, roughly in order
Expand All @@ -387,7 +404,9 @@ def __init__(

splits_path = pathlib.Path(splits_path)
if not splits_path.exists():
tmp_splits_path = dataset_path / "splits" / "splits-jsons" / splits_path
tmp_splits_path = (
dataset_path / "splits" / "splits-jsons" / splits_path
)
if not tmp_splits_path.exists():
raise FileNotFoundError(
f"Did not find `splits_path` using either absolute path ({splits_path})"
Expand All @@ -413,10 +432,9 @@ def __init__(
f"Valid `target_type` arguments are: {VALID_TARGET_TYPES}"
)
if isinstance(target_type, (list, tuple)):
if not all([
isinstance(target_type_, str)
for target_type_ in target_type
]):
if not all(
[isinstance(target_type_, str) for target_type_ in target_type]
):
types_in_target_types = set(
[type(target_type_) for target_type_ in target_type]
)
Expand All @@ -443,13 +461,15 @@ def __init__(
self.training_replicate_metadata = metadata_from_splits_json_path(
self.splits_path, self.dataset_path
)
self.frame_dur = self.training_replicate_metadata.frame_dur * 1e-3 # convert from ms to s!
self.frame_dur = (
self.training_replicate_metadata.frame_dur * 1e-3
) # convert from ms to s!

if "multi_frame_labels" in target_type:
labelmaps_json_path = self.dataset_path / "labelmaps.json"
if not labelmaps_json_path.exists():
raise FileNotFoundError(
"`target_type` includes \"multi_frame_labels\" but "
'`target_type` includes "multi_frame_labels" but '
"'labelmaps.json' was not found in root of dataset path:\n"
f"{labelmaps_json_path}"
)
Expand All @@ -472,10 +492,10 @@ def __init__(
f"group '{group}', unit '{unit}', and id '{id}'. "
"Please check that splits_json path is correct."
)
elif target_type == ('binary_frame_labels',):
self.labelmap = {'no segment': 0, 'segment': 1}
elif target_type == ('boundary_frame_labels',):
self.labelmap = {'no boundary': 0, 'boundary': 1}
elif target_type == ("binary_frame_labels",):
self.labelmap = {"no segment": 0, "segment": 1}
elif target_type == ("boundary_frame_labels",):
self.labelmap = {"no boundary": 0, "boundary": 1}

self.split = split
split_df = pd.read_csv(self.splits_metadata.splits_csv_path)
Expand Down Expand Up @@ -508,15 +528,20 @@ def __init__(
self.inds_in_sample = np.load(
getattr(self.splits_metadata.inds_in_sample_vector_paths, split)
)
self.window_inds = datapipes.frame_classification.train_datapipe.get_window_inds(
self.sample_ids.shape[-1], window_size, stride
self.window_inds = (
datapipes.frame_classification.train_datapipe.get_window_inds(
self.sample_ids.shape[-1], window_size, stride
)
)

if item_transform is None:
if standardize_frames and frames_standardizer is None:
from ..transforms import FramesStandardizer
frames_standardizer = FramesStandardizer.fit_inputs_targets_csv_path(
self.splits_metadata.splits_csv_path, self.dataset_path

frames_standardizer = (
FramesStandardizer.fit_inputs_targets_csv_path(
self.splits_metadata.splits_csv_path, self.dataset_path
)
)
if split == "train":
self.item_transform = TrainItemTransform(
Expand Down Expand Up @@ -580,7 +605,8 @@ def _getitem_train(self, idx):
item["frames"] = spect_dict[common.constants.SPECT_KEY]
for target_type in self.target_type:
item[target_type] = np.load(
self.dataset_path / self.target_paths[target_type][sample_id]
self.dataset_path
/ self.target_paths[target_type][sample_id]
)

elif len(uniq_sample_ids) > 1:
Expand All @@ -592,9 +618,7 @@ def _getitem_train(self, idx):
for sample_id in sorted(uniq_sample_ids):
frames_path = self.dataset_path / self.frames_paths[sample_id]
spect_dict = common.files.spect.load(frames_path)
item["frames"].append(
spect_dict[common.constants.SPECT_KEY]
)
item["frames"].append(spect_dict[common.constants.SPECT_KEY])
for target_type in self.target_type:
item[target_type].append(
np.load(
Expand Down
Loading

0 comments on commit 3763838

Please sign in to comment.