ENH: fix frame classification model to work with BioSoundSegBench #774

Merged Sep 9, 2024 (24 commits)
Commits
3a1a827  Hack a learning rate scheduler into FrameClassificationModel (NickleDave, Jul 9, 2024)
171fefe  Add boundary_labels parameter to transforms.frame_labels.transforms.P… (NickleDave, Jul 17, 2024)
e38f982  Add background_label to post_tfm_kwargs in src/vak/eval/frame_classif… (NickleDave, Jul 17, 2024)
6f5b609  In `validation_step` of FrameClassificationModel, use boundary labels… (NickleDave, Jul 17, 2024)
d3f00dc  Unpack `dataset_path` from dataset_config in code block for built-in … (NickleDave, Jul 18, 2024)
b5e6d49  Make variable `frame_dur` inside code block for built-in datasets ins… (NickleDave, Jul 18, 2024)
5676d85  Pass `background_label` into transforms.frame_labels.PostProcess insi… (NickleDave, Jul 18, 2024)
57ab70f  Fix how we call self.manual_backward in FrameClassificationModel to … (NickleDave, Jul 18, 2024)
a948b62  In FrameClassificationModel.validation_step, convert boundary_preds t… (NickleDave, Jul 18, 2024)
bc40abc  In FrameClassificationModel.validation_step, when logging accuracy, c… (NickleDave, Jul 18, 2024)
36c67bf  Change how we get and log frame_dur in train/frame_classification.py … (NickleDave, Jul 18, 2024)
fd20e45  Change one-line summary of __call__ method for frame_labels.transform… (NickleDave, Jul 18, 2024)
2164f35  BUG: Ensure boundary_labels is 1d in post-process transform, fix #767 (NickleDave, Jul 19, 2024)
f5d7046  Fix what metric we use for learning rate scheduler: use val_multi_acc… (NickleDave, Sep 4, 2024)
b055a07  Remove trainer module from common, code is used only for frame classi… (NickleDave, Sep 4, 2024)
1f83de8  Add get_trainer and get_callbacks to train/frame_classification.py, f… (NickleDave, Sep 4, 2024)
2452922  Add missing self.manual_backward in training_step of FrameClassificat… (NickleDave, Sep 5, 2024)
f40a3d4  Fix how we determine whether there are multiple targets and what to m… (NickleDave, Sep 5, 2024)
8f98f1c  Fix how we validate boundary_labels in transforms.frame_labels.functi… (NickleDave, Sep 6, 2024)
6111356  Fix vak/predict/frame_classification.py to handle edge case where no … (NickleDave, Sep 6, 2024)
5f0da8d  Revise comment (NickleDave, Sep 7, 2024)
02974d3  Catch edge case in transforms.frame_labels.functional.boundary_inds_f… (NickleDave, Sep 7, 2024)
a07f17b  Add minimal unit tests for vak.transforms.frame_labels.functional.bou… (NickleDave, Sep 9, 2024)
608c0fc  Remove learning rate scheduler for now (NickleDave, Sep 9, 2024)
2 changes: 0 additions & 2 deletions src/vak/common/__init__.py
@@ -21,7 +21,6 @@
     tensorboard,
     timebins,
     timenow,
-    trainer,
     typing,
     validators,
 )
@@ -39,7 +38,6 @@
     "tensorboard",
     "timebins",
     "timenow",
-    "trainer",
     "typing",
     "validators",
 ]
88 changes: 0 additions & 88 deletions src/vak/common/trainer.py

This file was deleted.

11 changes: 10 additions & 1 deletion src/vak/eval/frame_classification.py
@@ -14,7 +14,7 @@
 import torch.utils.data

 from .. import datapipes, datasets, models, transforms
-from ..common import validators
+from ..common import constants, validators
 from ..datapipes.frame_classification import InferDatapipe

 logger = logging.getLogger(__name__)
@@ -154,12 +154,20 @@ def eval_frame_classification_model(
         )
     # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
     else:
+        # next line, we don't use dataset path in this code block,
+        # but we need it below when we build the DataFrame with eval results.
+        # we're unpacking it here just as we do above with a prep'd dataset
+        dataset_path = pathlib.Path(dataset_config["path"])
         dataset_config["params"]["return_padding_mask"] = True
         val_dataset = datasets.get(
             dataset_config,
             split=split,
             frames_standardizer=frames_standardizer,
         )
+        frame_dur = val_dataset.frame_dur
+        logger.info(
+            f"Duration of a frame in dataset, in seconds: {frame_dur}",
+        )

     val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
@@ -179,6 +187,7 @@
     if post_tfm_kwargs:
         post_tfm = transforms.frame_labels.PostProcess(
             timebin_dur=frame_dur,
+            background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
             **post_tfm_kwargs,
         )
     else:
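To see how the added `background_label` argument is used, here is a minimal sketch of building the post-processing transform outside of `eval_frame_classification_model`; the `labelmap` contents, `post_tfm_kwargs` values, and frame duration are hypothetical stand-ins for what the function reads off the dataset and config:

```python
from vak import transforms
from vak.common import constants

# hypothetical labelmap; vak keys the background class by
# constants.DEFAULT_BACKGROUND_LABEL
labelmap = {constants.DEFAULT_BACKGROUND_LABEL: 0, "a": 1, "b": 2}
# hypothetical post-processing options
post_tfm_kwargs = {"majority_vote": True, "min_segment_dur": 0.01}

# mirrors the call in eval/frame_classification.py
post_tfm = transforms.frame_labels.PostProcess(
    timebin_dur=0.002,  # stands in for frame_dur, in seconds
    background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
    **post_tfm_kwargs,
)
```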
19 changes: 12 additions & 7 deletions src/vak/models/frame_classification_model.py
@@ -130,7 +130,6 @@ def __init__(
             :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
         """
         super().__init__()
-
         self.network = network
         self.loss = loss
         self.optimizer = optimizer
@@ -365,9 +364,15 @@ def validation_step(self, batch: tuple, batch_idx: int):
         class_preds_str = self.to_labels_eval(class_preds.cpu().numpy())

         if self.post_tfm:
-            class_preds_tfm = self.post_tfm(
-                class_preds.cpu().numpy(),
-            )
+            if target_types == ("multi_frame_labels",):
+                class_preds_tfm = self.post_tfm(
+                    class_preds.cpu().numpy(),
+                )
+            elif target_types == ("multi_frame_labels", "boundary_frame_labels"):
+                class_preds_tfm = self.post_tfm(
+                    class_preds.cpu().numpy(),
+                    boundary_labels=boundary_preds.cpu().numpy(),
+                )
             class_preds_tfm_str = self.to_labels_eval(class_preds_tfm)
             # convert back to tensor so we can compute accuracy
             class_preds_tfm = torch.from_numpy(class_preds_tfm).to(
@@ -395,8 +400,8 @@ def validation_step(self, batch: tuple, batch_idx: int):
             loss = self.loss(
                 class_logits,
                 boundary_logits,
-                batch["multi_frame_labels"],
-                batch["boundary_frame_labels"],
+                target["multi_frame_labels"],
+                target["boundary_frame_labels"],
             )
             if isinstance(loss, torch.Tensor):
                 self.log(
@@ -435,7 +440,7 @@ def validation_step(self, batch: tuple, batch_idx: int):
                 )
             else:
                 self.log(
-                    f"val_{metric_name}",
+                    f"val_multi_{metric_name}",
                     metric_callable(
                         class_preds, target["multi_frame_labels"]
                     ),
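The dispatch added to `validation_step` can be sketched as a standalone function; this is a simplification for illustration (in the model, `post_tfm` is an instance of `transforms.frame_labels.PostProcess` and predictions are moved off the GPU first):

```python
import numpy as np

def apply_post_tfm(post_tfm, class_preds, boundary_preds, target_types):
    """Sketch of the branch added in FrameClassificationModel.validation_step."""
    if target_types == ("multi_frame_labels",):
        # single target: post-process the class predictions alone
        return post_tfm(class_preds)
    elif target_types == ("multi_frame_labels", "boundary_frame_labels"):
        # multiple targets: predicted boundary labels guide post-processing
        return post_tfm(class_preds, boundary_labels=boundary_preds)
    raise ValueError(f"unrecognized target_types: {target_types}")
```

The other two changes in this file follow from multi-target training: the loss is computed against `target[...]` rather than `batch[...]`, and multi-class metrics are logged as `val_multi_<metric>` to distinguish them from boundary metrics, matching the `val_multi_acc` monitored by the training callbacks.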
7 changes: 5 additions & 2 deletions src/vak/predict/frame_classification.py
@@ -468,8 +468,11 @@ def predict_with_frame_classification_model(
                 annot_path=annot_csv_path.name,
             )
             annots.append(annot)
-
-    if all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
+    if len(annots) < 1:
+        # catch edge case where nothing was predicted
+        # FIXME: this should have columns that match GenericSeq
+        pd.DataFrame.from_records([]).to_csv(annot_csv_path)
+    elif all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
         generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
         generic_seq.to_file(annot_path=annot_csv_path)
     elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]):
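A sketch of the edge case this guards against, with a hypothetical output path:

```python
import pandas as pd

annots = []  # the model predicted no segments for any input file
annot_csv_path = "results/predict/annot.csv"  # hypothetical path

if len(annots) < 1:
    # still write a CSV (empty for now) so downstream code finds a file;
    # per the FIXME, its columns should eventually match GenericSeq
    pd.DataFrame.from_records([]).to_csv(annot_csv_path)
```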
117 changes: 108 additions & 9 deletions src/vak/train/frame_classification.py
@@ -8,13 +8,13 @@
 import pathlib
 import shutil

+import lightning
 import joblib
 import pandas as pd
 import torch.utils.data

 from .. import datapipes, datasets, models, transforms
 from ..common import validators
-from ..common.trainer import get_default_trainer
 from ..datapipes.frame_classification import InferDatapipe, TrainDatapipe

 logger = logging.getLogger(__name__)
@@ -25,6 +25,92 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float:
     return df[df["split"] == split]["duration"].sum()


+def get_train_callbacks(
+    ckpt_root: str | pathlib.Path,
+    ckpt_step: int,
+    patience: int,
+    checkpoint_monitor: str = "val_acc",
+    early_stopping_monitor: str = "val_acc",
+    early_stopping_mode: str = "max",
+) -> list[lightning.pytorch.callbacks.Callback]:
+    ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        dirpath=ckpt_root,
+        filename="checkpoint",
+        every_n_train_steps=ckpt_step,
+        save_last=True,
+        verbose=True,
+    )
+    ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
+    ckpt_callback.FILE_EXTENSION = ".pt"
+
+    val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        monitor=checkpoint_monitor,
+        dirpath=ckpt_root,
+        save_top_k=1,
+        mode="max",
+        filename="max-val-acc-checkpoint",
+        auto_insert_metric_name=False,
+        verbose=True,
+    )
+    val_ckpt_callback.FILE_EXTENSION = ".pt"
+
+    early_stopping = lightning.pytorch.callbacks.EarlyStopping(
+        mode=early_stopping_mode,
+        monitor=early_stopping_monitor,
+        patience=patience,
+        verbose=True,
+    )
+
+    return [ckpt_callback, val_ckpt_callback, early_stopping]
+
+
+def get_trainer(
+    accelerator: str,
+    devices: int | list[int],
+    max_steps: int,
+    log_save_dir: str | pathlib.Path,
+    val_step: int,
+    callback_kwargs: dict | None = None,
+) -> lightning.pytorch.Trainer:
+    """Returns an instance of :class:`lightning.pytorch.Trainer`
+    with a default set of callbacks.
+
+    Used by :func:`vak.train.frame_classification`.
+    The default set of callbacks is provided by
+    :func:`get_train_callbacks`.
+
+    Parameters
+    ----------
+    accelerator : str
+    devices : int, list of int
+    max_steps : int
+    log_save_dir : str, pathlib.Path
+    val_step : int
+    callback_kwargs : dict, optional
+
+    Returns
+    -------
+    trainer : lightning.pytorch.Trainer
+    """
+    if callback_kwargs:
+        callbacks = get_train_callbacks(**callback_kwargs)
+    else:
+        callbacks = None
+
+    logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)
+
+    trainer = lightning.pytorch.Trainer(
+        accelerator=accelerator,
+        devices=devices,
+        callbacks=callbacks,
+        val_check_interval=val_step,
+        max_steps=max_steps,
+        logger=logger,
+    )
+    return trainer
+
+
 def train_frame_classification_model(
     model_config: dict,
     dataset_config: dict,
@@ -245,8 +331,9 @@ def train_frame_classification_model(
         dataset_config,
         split="train",
     )
+    frame_dur = train_dataset.frame_dur
     logger.info(
-        f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}",
+        f"Duration of a frame in dataset, in seconds: {frame_dur}",
    )
     # copy labelmap from dataset to new results_path
     labelmap = train_dataset.labelmap
@@ -334,18 +421,30 @@ def train_frame_classification_model(
     ckpt_root.mkdir()
     logger.info(f"training {model_name}")
     max_steps = num_epochs * len(train_loader)
-    default_callback_kwargs = {
-        "ckpt_root": ckpt_root,
-        "ckpt_step": ckpt_step,
-        "patience": patience,
-    }
-    trainer = get_default_trainer(
+    if isinstance(dataset_config["params"]["target_type"], list) and all(
+        [isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]
+    ):
+        multiple_targets = True
+    elif isinstance(dataset_config["params"]["target_type"], str):
+        multiple_targets = False
+    else:
+        raise ValueError(
+            "Invalid value for dataset_config['params']['target_type']: "
+            f"{dataset_config['params']['target_type']}"
+        )
+
+    callback_kwargs = dict(
+        ckpt_root=ckpt_root,
+        ckpt_step=ckpt_step,
+        patience=patience,
+        checkpoint_monitor="val_multi_acc" if multiple_targets else "val_acc",
+        early_stopping_monitor="val_multi_acc" if multiple_targets else "val_acc",
+        early_stopping_mode="max",
+    )
+    trainer = get_trainer(
         accelerator=trainer_config["accelerator"],
         devices=trainer_config["devices"],
         max_steps=max_steps,
         log_save_dir=results_model_root,
         val_step=val_step,
-        default_callback_kwargs=default_callback_kwargs,
+        callback_kwargs=callback_kwargs,
     )
     train_time_start = datetime.datetime.now()
     logger.info(f"Training start time: {train_time_start.isoformat()}")
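A usage sketch for the new helpers; the values are hypothetical stand-ins for what `train_frame_classification_model` pulls from `trainer_config` and the dataset config:

```python
from vak.train.frame_classification import get_trainer

callback_kwargs = dict(
    ckpt_root="results/ckpts",  # hypothetical paths and values throughout
    ckpt_step=500,
    patience=4,
    # with multiple target types, monitor multi-class accuracy
    checkpoint_monitor="val_multi_acc",
    early_stopping_monitor="val_multi_acc",
    early_stopping_mode="max",
)
trainer = get_trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=10000,
    log_save_dir="results/logs",
    val_step=400,
    callback_kwargs=callback_kwargs,
)
```

Moving these helpers here from `vak.common.trainer` reflects that only frame-classification training used them, per commit b055a07.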
14 changes: 11 additions & 3 deletions src/vak/transforms/frame_labels/functional.py
@@ -401,11 +401,17 @@ def boundary_inds_from_boundary_labels(
         If ``True``, and the first index of ``boundary_labels`` is not classified as a boundary,
         force it to be a boundary.
     """
+    boundary_labels = row_or_1d(boundary_labels)
     boundary_inds = np.nonzero(boundary_labels)[0]

-    if boundary_inds[0] != 0 and force_boundary_first_ind:
-        # force there to be a boundary at index 0
-        np.insert(boundary_inds, 0, 0)
+    if force_boundary_first_ind:
+        if len(boundary_inds) == 0:
+            # handle edge case where no boundaries were predicted:
+            # replace with a single boundary, at index 0
+            boundary_inds = np.array([0])
+        else:
+            if boundary_inds[0] != 0:
+                # force there to be a boundary at index 0
+                boundary_inds = np.insert(boundary_inds, 0, 0)

     return boundary_inds

@@ -531,6 +537,8 @@ def postprocess(
         Vector of frame labels after post-processing is applied.
     """
     frame_labels = row_or_1d(frame_labels)
+    if boundary_labels is not None:
+        boundary_labels = row_or_1d(boundary_labels)

     # handle the case when all time bins are predicted to be unlabeled
     # see https://github.com/NickleDave/vak/issues/383
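A worked example of the edge case caught in `boundary_inds_from_boundary_labels`, using a hypothetical input:

```python
import numpy as np

# boundary_labels is a binary vector with one value per frame
boundary_labels = np.array([0, 0, 0, 0])  # edge case: no boundaries predicted
boundary_inds = np.nonzero(boundary_labels)[0]  # empty array

# previously, boundary_inds[0] raised an IndexError here; now, with
# force_boundary_first_ind=True, we fall back to a single boundary at index 0
if len(boundary_inds) == 0:
    boundary_inds = np.array([0])

print(boundary_inds)  # [0]
```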