From 3a1a827efbb7851471cf78c681ba6eb54514af92 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 9 Jul 2024 10:16:24 -0400 Subject: [PATCH 01/24] Hack a learning rate scheduler into FrameClassificationModel --- src/vak/models/frame_classification_model.py | 24 ++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 2da299cf3..8f44f861b 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -130,6 +130,7 @@ def __init__( :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`. """ super().__init__() + self.automatic_optimization = False # so we can use learning rate scheduler self.network = network self.loss = loss @@ -173,8 +174,13 @@ def configure_optimizers(self): If None was passed in, an instance that was created with default arguments will be returned. """ - return self.optimizer - + optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=4) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler + } + def forward(self, x: torch.Tensor) -> torch.Tensor: """Run a forward pass through this model's network. @@ -209,6 +215,8 @@ def training_step(self, batch: tuple, batch_idx: int): the loss function, ``self.loss``. """ frames = batch["frames"] + opt = self.optimizers() + opt.zero_grad() # we repeat this code in validation step # because I'm assuming it's faster than a call to a staticmethod that factors it out @@ -256,6 +264,9 @@ def training_step(self, batch: tuple, batch_idx: int): for loss_name, loss_val in loss.items(): self.log(f"train_{loss_name}", loss_val, on_step=True) + self.manual_backward(loss) + opt.step() + return loss def validation_step(self, batch: tuple, batch_idx: int): @@ -492,6 +503,15 @@ def validation_step(self, batch: tuple, batch_idx: int): sync_dist=True, ) + def on_validation_end(self): + # adding this method is so we can call learning rate scheduler after computing validation metrics + scheduler = self.lr_schedulers() + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + scheduler.step(self.trainer.callback_metrics["val_acc"]) + lr = scheduler.get_last_lr() + logger = self.logger.experiment + logger.add_scalar('learning_rate', lr[-1], global_step=self.trainer.global_step) + def predict_step(self, batch: tuple, batch_idx: int): """Perform one prediction step. From 171fefe8118054bd86362396029879a9ca9fed31 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 17 Jul 2024 08:05:18 -0400 Subject: [PATCH 02/24] Add boundary_labels parameter to transforms.frame_labels.transforms.PostProcess.__call__ --- src/vak/transforms/frame_labels/transforms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index 0dab0c504..ededfe543 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -24,6 +24,7 @@ from __future__ import annotations import numpy as np +import numpy.typing as npt from . 
import functional as F @@ -258,7 +259,7 @@ def __init__( self.min_segment_dur = min_segment_dur self.majority_vote = majority_vote - def __call__(self, frame_labels: np.ndarray) -> np.ndarray: + def __call__(self, frame_labels: np.ndarray, boundary_labels: npt.NDArray | None = None) -> np.ndarray: """Convert vector of frame labels into labels. Parameters ---------- @@ -280,4 +281,5 @@ def __call__(self, frame_labels: np.ndarray) -> np.ndarray: self.background_label, self.min_segment_dur, self.majority_vote, + boundary_labels, ) From e38f9820d2388afccf9c8f571957d4754eb0f9e8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 17 Jul 2024 08:05:46 -0400 Subject: [PATCH 03/24] Add background_label to post_tfm_kwargs in src/vak/eval/frame_classification.py --- src/vak/eval/frame_classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 12688288f..7b8f1ddb6 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -14,7 +14,7 @@ import torch.utils.data from .. import datapipes, datasets, models, transforms -from ..common import validators +from ..common import constants, validators from ..datapipes.frame_classification import InferDatapipe logger = logging.getLogger(__name__) @@ -179,6 +179,7 @@ def eval_frame_classification_model( if post_tfm_kwargs: post_tfm = transforms.frame_labels.PostProcess( timebin_dur=frame_dur, + background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL] **post_tfm_kwargs, ) else: From 6f5b6098439d1aa315ca9dd4c2810376312ba6dd Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 17 Jul 2024 08:07:30 -0400 Subject: [PATCH 04/24] In `validation_step` of FrameClassificationModel, use boundary labels when they are present to post-process multi-class frame labels --- src/vak/models/frame_classification_model.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 8f44f861b..da89d7aa6 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -376,9 +376,15 @@ def validation_step(self, batch: tuple, batch_idx: int): class_preds_str = self.to_labels_eval(class_preds.cpu().numpy()) if self.post_tfm: - class_preds_tfm = self.post_tfm( - class_preds.cpu().numpy(), - ) + if target_types == ("multi_frame_labels",): + class_preds_tfm = self.post_tfm( + class_preds.cpu().numpy(), + ) + elif target_types == ("multi_frame_labels", "boundary_frame_labels"): + class_preds_tfm = self.post_tfm( + class_preds.cpu().numpy(), + boundary_labels=boundary_preds, + ) class_preds_tfm_str = self.to_labels_eval(class_preds_tfm) # convert back to tensor so we can compute accuracy class_preds_tfm = torch.from_numpy(class_preds_tfm).to( @@ -406,8 +412,8 @@ def validation_step(self, batch: tuple, batch_idx: int): loss = self.loss( class_logits, boundary_logits, - batch["multi_frame_labels"], - batch["boundary_frame_labels"], + target["multi_frame_labels"], + target["boundary_frame_labels"], ) if isinstance(loss, torch.Tensor): self.log(
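Patches 02-04 thread predicted boundary labels through post-processing. A minimal sketch of calling the updated transform; the time-bin duration, background label, and post_tfm_kwargs values here are illustrative assumptions, not values from the patches:

```python
import numpy as np

from vak import transforms

post_tfm = transforms.frame_labels.PostProcess(
    timebin_dur=0.002,     # duration of one frame, in seconds (illustrative)
    background_label=0,    # integer the labelmap assigns to the background class
    min_segment_dur=0.01,  # example post_tfm_kwargs value
    majority_vote=True,    # example post_tfm_kwargs value
)

frame_labels = np.array([0, 1, 2, 1, 1, 0, 0, 2, 2, 2])
boundary_labels = np.array([1, 1, 0, 0, 0, 1, 0, 1, 0, 0])
# boundary_labels is optional; when given, the predicted boundaries
# define the segments that post-processing operates on
frame_labels_tfm = post_tfm(frame_labels, boundary_labels=boundary_labels)
```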
From d3f00dc2c8bfc4faa775638f5832ea02bde72c99 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:29:12 -0400 Subject: [PATCH 05/24] Unpack `dataset_path` from dataset_config in code block for built-in datasets in eval/frame_classification.py, to make sure this variable exists when we build the DataFrame with eval results --- src/vak/eval/frame_classification.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 7b8f1ddb6..1e0513905 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -154,6 +154,10 @@ def eval_frame_classification_model( ) # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------ else: + # next line, we don't use dataset path in this code block, + # but we need it below when we build the DataFrame with eval results. + # we're unpacking it here just as we do above with a prep'd dataset + dataset_path = pathlib.Path(dataset_config["path"]) dataset_config["params"]["return_padding_mask"] = True val_dataset = datasets.get( dataset_config, From b5e6d49c6c5e17f29b6426a542f61635e785aa4e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:29:54 -0400 Subject: [PATCH 06/24] Make variable `frame_dur` inside code block for built-in datasets inside eval/frame_classification.py so this variable exists when we get the post-processing transform --- src/vak/eval/frame_classification.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 1e0513905..1ea42bbaa 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -164,6 +164,10 @@ def eval_frame_classification_model( split=split, frames_standardizer=frames_standardizer, ) + frame_dur = val_dataset.frame_dur + logger.info( + f"Duration of a frame in dataset, in seconds: {frame_dur}", + ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, From 5676d85adc484ec0116db8d23c7f702a4bd45fef Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:30:55 -0400 Subject: [PATCH 07/24] Pass `background_label` into transforms.frame_labels.PostProcess inside eval/frame_classification.py, using `constants.DEFAULT_BACKGROUND_LABEL` --- src/vak/eval/frame_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 1ea42bbaa..665a0bb4f 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -187,7 +187,7 @@ def eval_frame_classification_model( if post_tfm_kwargs: post_tfm = transforms.frame_labels.PostProcess( timebin_dur=frame_dur, - background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL] + background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL], **post_tfm_kwargs, ) else:
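The fix in the next patch assumes a loss whose forward returns a dict with the total under a "loss" key, alongside the individual terms. A hypothetical sketch of such a multi-term loss (not a class from vak):

```python
import torch


class MultiTermLoss(torch.nn.Module):
    """Illustrative loss with a multi-class term and a boundary term."""

    def __init__(self):
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss()
        self.bce = torch.nn.BCEWithLogitsLoss()

    def forward(self, multi_logits, boundary_logits, multi_labels, boundary_labels):
        loss_multi = self.ce(multi_logits, multi_labels)
        loss_boundary = self.bce(boundary_logits, boundary_labels.float())
        return {
            "loss": loss_multi + loss_boundary,  # the term passed to manual_backward
            "loss_multi": loss_multi,
            "loss_boundary": loss_boundary,
        }
```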
From 57ab70febfd81e35e426ea4090a5bc7c841c1786 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:32:13 -0400 Subject: [PATCH 08/24] Fix how we call self.manual_backward in FrameClassificationModel to handle the case when the loss function returns a dict --- src/vak/models/frame_classification_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index da89d7aa6..6cbf21ed7 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -259,12 +259,13 @@ def training_step(self, batch: tuple, batch_idx: int): ) if isinstance(loss, torch.Tensor): self.log("train_loss", loss, on_step=True) + self.manual_backward(loss) elif isinstance(loss, dict): # this provides a mechanism to log values for all terms of a loss function with multiple terms for loss_name, loss_val in loss.items(): self.log(f"train_{loss_name}", loss_val, on_step=True) + self.manual_backward(loss["loss"]) - self.manual_backward(loss) opt.step() return loss From a948b6258e0296b04b6eac7e89e53d97931eb9e0 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:32:54 -0400 Subject: [PATCH 09/24] In FrameClassificationModel.validation_step, convert boundary_preds to numpy when we pass them in to self.post_tfm --- src/vak/models/frame_classification_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 6cbf21ed7..94580ee83 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -384,7 +384,7 @@ def validation_step(self, batch: tuple, batch_idx: int): elif target_types == ("multi_frame_labels", "boundary_frame_labels"): class_preds_tfm = self.post_tfm( class_preds.cpu().numpy(), - boundary_labels=boundary_preds, + boundary_labels=boundary_preds.cpu().numpy(), ) class_preds_tfm_str = self.to_labels_eval(class_preds_tfm) # convert back to tensor so we can compute accuracy From bc40abc57cff6732d193bef72e37df5095a3ee76 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:33:53 -0400 Subject: [PATCH 10/24] In FrameClassificationModel.validation_step, when logging accuracy, call it 'val_multi_acc' to distinguish from boundary_acc and for consistency with val_multi_acc_tfm --- src/vak/models/frame_classification_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 94580ee83..706183203 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -453,7 +453,7 @@ def validation_step(self, batch: tuple, batch_idx: int): ) else: self.log( - f"val_{metric_name}", + f"val_multi_{metric_name}", metric_callable( class_preds, target["multi_frame_labels"] ), From 36c67bf927c6536d8d79f8b91090fca6a02d58cb Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:34:45 -0400 Subject: [PATCH 11/24] Change how we get and log frame_dur in train/frame_classification.py so we have it as a separate variable; will use for post_tfm kwargs when we add those later --- src/vak/train/frame_classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index f83f6518f..a97bc4107 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -245,8 +245,9 @@ def train_frame_classification_model( dataset_config, split="train", ) + frame_dur = train_dataset.frame_dur logger.info( - f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}", + f"Duration of a frame in dataset, in seconds: {frame_dur}", ) # copy labelmap from dataset to new results_path labelmap = train_dataset.labelmap From fd20e45650a417f1b535c102ac2837dc3075c169 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 18 Jul 2024 10:35:20 -0400 Subject: [PATCH 12/24] Change one-line summary of __call__ method for frame_labels.transforms.PostProcess --- src/vak/transforms/frame_labels/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index
ededfe543..024a767ba 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -260,7 +260,8 @@ def __init__( self.majority_vote = majority_vote def __call__(self, frame_labels: np.ndarray, boundary_labels: npt.NDArray | None = None) -> np.ndarray: - """Convert vector of frame labels into labels. + """Apply post-processing transformations + to a vector of frame labels. Parameters ---------- From 2164f351c92cd1007a08480924a864921bc6dc9f Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 19 Jul 2024 09:12:50 -0400 Subject: [PATCH 13/24] BUG: Ensure boundary_labels is 1d in post-process transform, fix #767 --- src/vak/transforms/frame_labels/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index 09e9eaf2e..ed4730aa2 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -401,6 +401,7 @@ def boundary_inds_from_boundary_labels( If ``True``, and the first index of ``boundary_labels`` is not classified as a boundary, force it to be a boundary. """ + boundary_labels = row_or_1d(boundary_labels) boundary_inds = np.nonzero(boundary_labels)[0] if boundary_inds[0] != 0 and force_boundary_first_ind: @@ -531,6 +532,7 @@ def postprocess( Vector of frame labels after post-processing is applied. """ frame_labels = row_or_1d(frame_labels) + boundary_labels = row_or_1d(boundary_labels) # handle the case when all time bins are predicted to be unlabeled # see https://github.com/NickleDave/vak/issues/383 From f5d7046ee880a96c172d92f1b1b8996f9f781d8c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 4 Sep 2024 10:57:11 -0400 Subject: [PATCH 14/24] Fix what metric we use for learning rate scheduler: use val_multi_acc for models with multiple accuracies --- src/vak/models/frame_classification_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 706183203..1000fc399 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -514,7 +514,11 @@ def on_validation_end(self): # adding this method is so we can call learning rate scheduler after computing validation metrics scheduler = self.lr_schedulers() if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - scheduler.step(self.trainer.callback_metrics["val_acc"]) + if "val_multi_acc" in self.trainer.callback_metrics: + # for segnotator, we plateau on multi-class frame accuracy + scheduler.step(self.trainer.callback_metrics["val_multi_acc"]) + else: + scheduler.step(self.trainer.callback_metrics["val_acc"]) lr = scheduler.get_last_lr() logger = self.logger.experiment logger.add_scalar('learning_rate', lr[-1], global_step=self.trainer.global_step) From b055a07ed8282d7509e4a0d721471e2e1122688a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 4 Sep 2024 11:22:16 -0400 Subject: [PATCH 15/24] Remove trainer module from common, code is used only for frame classification model --- src/vak/common/__init__.py | 2 - src/vak/common/trainer.py | 88 -------------------------------------- 2 files changed, 90 deletions(-) delete mode 100644 src/vak/common/trainer.py diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index c5be9ccfd..84e1190b3 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -21,7 +21,6 @@ 
tensorboard, timebins, timenow, - trainer, typing, validators, ) @@ -39,7 +38,6 @@ "tensorboard", "timebins", "timenow", - "trainer", "typing", "validators", ] diff --git a/src/vak/common/trainer.py b/src/vak/common/trainer.py deleted file mode 100644 index cd12d03d4..000000000 --- a/src/vak/common/trainer.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import pathlib - -import lightning - - -def get_default_train_callbacks( - ckpt_root: str | pathlib.Path, - ckpt_step: int, - patience: int, -): - ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( - dirpath=ckpt_root, - filename="checkpoint", - every_n_train_steps=ckpt_step, - save_last=True, - verbose=True, - ) - ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" - ckpt_callback.FILE_EXTENSION = ".pt" - - val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( - monitor="val_acc", - dirpath=ckpt_root, - save_top_k=1, - mode="max", - filename="max-val-acc-checkpoint", - auto_insert_metric_name=False, - verbose=True, - ) - val_ckpt_callback.FILE_EXTENSION = ".pt" - - early_stopping = lightning.pytorch.callbacks.EarlyStopping( - mode="max", - monitor="val_acc", - patience=patience, - verbose=True, - ) - - return [ckpt_callback, val_ckpt_callback, early_stopping] - - -def get_default_trainer( - accelerator: str, - devices: int | list[int], - max_steps: int, - log_save_dir: str | pathlib.Path, - val_step: int, - default_callback_kwargs: dict | None = None, -) -> lightning.pytorch.Trainer: - """Returns an instance of :class:`lightning.pytorch.Trainer` - with a default set of callbacks. - - Used by :func:`vak.train.frame_classification`. - The default set of callbacks is provided by - :func:`get_default_train_callbacks`. - - Parameters - ---------- - accelerator : str - devices : int, list of int - max_steps : int - log_save_dir : str, pathlib.Path - val_step : int - default_callback_kwargs : dict, optional - - Returns - ------- - trainer : lightning.pytorch.Trainer - - """ - if default_callback_kwargs: - callbacks = get_default_train_callbacks(**default_callback_kwargs) - else: - callbacks = None - - logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) - - trainer = lightning.pytorch.Trainer( - accelerator=accelerator, - devices=devices, - callbacks=callbacks, - val_check_interval=val_step, - max_steps=max_steps, - logger=logger, - ) - return trainer From 1f83de8dceac1358c9c4ec1dd55ff12eb18c5348 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 4 Sep 2024 11:23:05 -0400 Subject: [PATCH 16/24] Add get_trainer and get_callbacks to train/frame_classification.py, fix so that we monitor 'val_multi_acc' when a model has multiple targets, and just 'val_acc' otherwise --- src/vak/train/frame_classification.py | 105 ++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 8 deletions(-) diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index a97bc4107..e7f7a1181 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -8,13 +8,13 @@ import pathlib import shutil +import lightning import joblib import pandas as pd import torch.utils.data from .. 
import datapipes, datasets, models, transforms from ..common import validators -from ..common.trainer import get_default_trainer from ..datapipes.frame_classification import InferDatapipe, TrainDatapipe logger = logging.getLogger(__name__) @@ -25,6 +25,92 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: return df[df["split"] == split]["duration"].sum() +def get_train_callbacks( + ckpt_root: str | pathlib.Path, + ckpt_step: int, + patience: int, + checkpoint_monitor: str = "val_acc", + early_stopping_monitor: str = "val_acc", + early_stopping_mode: str = "max", +) -> list[lightning.pytorch.callbacks.Callback]: + ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=ckpt_root, + filename="checkpoint", + every_n_train_steps=ckpt_step, + save_last=True, + verbose=True, + ) + ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" + ckpt_callback.FILE_EXTENSION = ".pt" + + val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( + monitor=checkpoint_monitor, + dirpath=ckpt_root, + save_top_k=1, + mode="max", + filename="max-val-acc-checkpoint", + auto_insert_metric_name=False, + verbose=True, + ) + val_ckpt_callback.FILE_EXTENSION = ".pt" + + early_stopping = lightning.pytorch.callbacks.EarlyStopping( + mode=early_stopping_mode, + monitor=early_stopping_monitor, + patience=patience, + verbose=True, + ) + + return [ckpt_callback, val_ckpt_callback, early_stopping] + + +def get_trainer( + accelerator: str, + devices: int | list[int], + max_steps: int, + log_save_dir: str | pathlib.Path, + val_step: int, + callback_kwargs: dict | None = None, +) -> lightning.pytorch.Trainer: + """Returns an instance of :class:`lightning.pytorch.Trainer` + with a default set of callbacks. + + Used by :func:`vak.train.frame_classification`. + The default set of callbacks is provided by + :func:`get_train_callbacks`.
+ + Parameters + ---------- + accelerator : str + devices : int, list of int + max_steps : int + log_save_dir : str, pathlib.Path + val_step : int + callback_kwargs : dict, optional + + Returns + ------- + trainer : lightning.pytorch.Trainer + + """ + if callback_kwargs: + callbacks = get_train_callbacks(**callback_kwargs) + else: + callbacks = None + + logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) + + trainer = lightning.pytorch.Trainer( + accelerator=accelerator, + devices=devices, + callbacks=callbacks, + val_check_interval=val_step, + max_steps=max_steps, + logger=logger, + ) + return trainer + + def train_frame_classification_model( model_config: dict, dataset_config: dict, @@ -335,18 +421,21 @@ def train_frame_classification_model( ckpt_root.mkdir() logger.info(f"training {model_name}") max_steps = num_epochs * len(train_loader) - default_callback_kwargs = { - "ckpt_root": ckpt_root, - "ckpt_step": ckpt_step, - "patience": patience, - } - trainer = get_default_trainer( + callback_kwargs = dict( + ckpt_root=ckpt_root, + ckpt_step=ckpt_step, + patience=patience, + checkpoint_monitor="val_multi_acc" if len(dataset_config["params"]["target_type"]) > 1 else "val_acc", + early_stopping_monitor="val_multi_acc" if len(dataset_config["params"]["target_type"]) > 1 else "val_acc", + early_stopping_mode="max", + ) + trainer = get_trainer( accelerator=trainer_config["accelerator"], devices=trainer_config["devices"], max_steps=max_steps, log_save_dir=results_model_root, val_step=val_step, - default_callback_kwargs=default_callback_kwargs, + callback_kwargs=callback_kwargs, ) train_time_start = datetime.datetime.now() logger.info(f"Training start time: {train_time_start.isoformat()}") From 24529226878f3e2e07694e83e270dd9b6018fdfc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 5 Sep 2024 13:49:07 -0400 Subject: [PATCH 17/24] Add missing self.manual_backward in training_step of FrameClassificationModel --- src/vak/models/frame_classification_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 1000fc399..1b27061a2 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -249,6 +249,7 @@ def training_step(self, batch: tuple, batch_idx: int): class_logits = self.network(frames) loss = self.loss(class_logits, batch[target_types[0]]) self.log("train_loss", loss, on_step=True) + self.manual_backward(loss) else: multi_logits, boundary_logits = self.network(frames) loss = self.loss(
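The check added in the next patch distinguishes a single target from multiple targets by the shape of `target_type` in the dataset config. A minimal sketch with illustrative values:

```python
# target_type is a list of strings when a model trains on multiple targets,
# and a plain string when it trains on one
dataset_config = {
    "params": {"target_type": ["multi_frame_labels", "boundary_frame_labels"]}
}

target_type = dataset_config["params"]["target_type"]
multiple_targets = isinstance(target_type, list)  # True here; False for a str
monitor = "val_multi_acc" if multiple_targets else "val_acc"
print(monitor)  # val_multi_acc
```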
From f40a3d420f5598dba61779b585b46af09d854184 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 5 Sep 2024 13:49:41 -0400 Subject: [PATCH 18/24] Fix how we determine whether there are multiple targets and what to monitor in train/frame_classification.py --- src/vak/train/frame_classification.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index e7f7a1181..74f6a269e 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -421,12 +421,21 @@ def train_frame_classification_model( ckpt_root.mkdir() logger.info(f"training {model_name}") max_steps = num_epochs * len(train_loader) + if isinstance(dataset_config["params"]["target_type"], list) and all([isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]): + multiple_targets = True + elif isinstance(dataset_config["params"]["target_type"], str): + multiple_targets = False + else: + raise ValueError( + f'Invalid value for dataset_config["params"]["target_type"]: {dataset_config["params"]["target_type"]}' + ) + callback_kwargs = dict( ckpt_root=ckpt_root, ckpt_step=ckpt_step, patience=patience, - checkpoint_monitor="val_multi_acc" if len(dataset_config["params"]["target_type"]) > 1 else "val_acc", - early_stopping_monitor="val_multi_acc" if len(dataset_config["params"]["target_type"]) > 1 else "val_acc", + checkpoint_monitor="val_multi_acc" if multiple_targets else "val_acc", + early_stopping_monitor="val_multi_acc" if multiple_targets else "val_acc", early_stopping_mode="max", ) From 8f98f1cb6615098114fffcdff75970502cd04a1b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 6 Sep 2024 10:47:10 -0400 Subject: [PATCH 19/24] Fix how we validate boundary_labels in transforms.frame_labels.functional.postprocess -- don't if boundary_labels is None --- src/vak/transforms/frame_labels/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index ed4730aa2..6e7adcc7c 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -532,7 +532,8 @@ def postprocess( Vector of frame labels after post-processing is applied. """ frame_labels = row_or_1d(frame_labels) - boundary_labels = row_or_1d(boundary_labels) + if boundary_labels is not None: + boundary_labels = row_or_1d(boundary_labels) # handle the case when all time bins are predicted to be unlabeled # see https://github.com/NickleDave/vak/issues/383 From 6111356baea0ec6c8f0202af138f1d6a932d383b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 6 Sep 2024 13:53:25 -0400 Subject: [PATCH 20/24] Fix vak/predict/frame_classification.py to handle edge case where no non-background segments are predicted for any sample in dataset --- src/vak/predict/frame_classification.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index eff70ef48..148de01ff 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -468,8 +468,11 @@ def predict_with_frame_classification_model( annot_path=annot_csv_path.name, ) annots.append(annot) - - if all([isinstance(annot, crowsetta.Annotation) for annot in annots]): + if len(annots) < 1: + # catch edge case where nothing was predicted + # FIXME: this should have columns that match GenericSeq + pd.DataFrame.from_records([]).to_csv(annot_csv_path) + elif all([isinstance(annot, crowsetta.Annotation) for annot in annots]): generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots) generic_seq.to_file(annot_path=annot_csv_path) elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]):
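One reason the length check the previous patch adds has to come first: `all()` over an empty list is vacuously True, so with no predictions the `crowsetta.Annotation` branch would otherwise run on an empty `annots` list. A quick illustration:

```python
annots = []
# every "all annotations are of type X" check passes vacuously on an empty list
print(all(isinstance(annot, str) for annot in annots))  # True
```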
From 5f0da8dc75efdadb3cbf7e9abaa11584801fa07a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 7 Sep 2024 09:52:23 -0400 Subject: [PATCH 21/24] Revise comment --- src/vak/models/frame_classification_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 1b27061a2..581988157 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -516,7 +516,7 @@ def on_validation_end(self): scheduler = self.lr_schedulers() if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): if "val_multi_acc" in self.trainer.callback_metrics: - # for segnotator, we plateau on multi-class frame accuracy + # for the case where we have multiple accuracies, we step the scheduler on multi-class frame accuracy scheduler.step(self.trainer.callback_metrics["val_multi_acc"]) else: scheduler.step(self.trainer.callback_metrics["val_acc"]) From 02974d3b86e776481322cc36d04a1b7b7fa15304 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 7 Sep 2024 09:53:01 -0400 Subject: [PATCH 22/24] Catch edge case in transforms.frame_labels.functional.boundary_inds_from_boundary_labels --- src/vak/transforms/frame_labels/functional.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index 6e7adcc7c..580d3229a 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -404,9 +404,14 @@ def boundary_inds_from_boundary_labels( boundary_labels = row_or_1d(boundary_labels) boundary_inds = np.nonzero(boundary_labels)[0] - if boundary_inds[0] != 0 and force_boundary_first_ind: - # force there to be a boundary at index 0 - np.insert(boundary_inds, 0, 0) + if force_boundary_first_ind: + if len(boundary_inds) == 0: + # handle edge case where no boundaries were predicted + boundary_inds = np.array([0]) # replace with a single boundary, at index 0 + else: + if boundary_inds[0] != 0: + # force there to be a boundary at index 0 + boundary_inds = np.insert(boundary_inds, 0, 0) return boundary_inds From a07f17b0b252963950e90b77b8ac68a504a7e5fd Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 9 Sep 2024 14:44:22 -0400 Subject: [PATCH 23/24] Add minimal unit tests for vak.transforms.frame_labels.functional.boundary_inds_from_boundary_labels --- .../test_frame_labels/test_functional.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_transforms/test_frame_labels/test_functional.py b/tests/test_transforms/test_frame_labels/test_functional.py index 9e9013eb6..eed67504f 100644 --- a/tests/test_transforms/test_frame_labels/test_functional.py +++ b/tests/test_transforms/test_frame_labels/test_functional.py @@ -246,6 +246,45 @@ def test_to_segments_real_data( assert np.all(np.abs(annot.seq.offsets_s - offsets_s) < MAX_ABS_DIFF) +@pytest.mark.parametrize( + "boundary_labels, boundary_inds_expected", + [ + ( + np.array([1,0,0,0,1,0,0]), + np.array([0,4]) + ), + ] +) +def test_boundary_inds_from_boundary_labels(boundary_labels, boundary_inds_expected): + boundary_inds = vak.transforms.frame_labels.boundary_inds_from_boundary_labels( + boundary_labels + ) + assert np.array_equal(boundary_inds, boundary_inds_expected) + + +@pytest.mark.parametrize( + "boundary_labels, expected_exception", + [ + # 3-d array should raise a ValueError, needs to be row or 1-d + ( + np.array([[[1,0,0,0,1,0,0]]]), + ValueError + ), + # column vector should raise a ValueError, needs to be row or 1-d + ( + np.array([[1],[0],[0]]), + ValueError + ) + + ] +) +def test_boundary_inds_from_boundary_labels_raises(boundary_labels, expected_exception): + with pytest.raises(expected_exception): + vak.transforms.frame_labels.functional.segment_inds_list_from_boundary_labels( + boundary_labels + ) + + @pytest.mark.parametrize( "frame_labels, seg_inds_list_expected", [
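One subtlety in Patch 22's `boundary_inds_from_boundary_labels` hunk above: `np.insert` returns a new array rather than modifying its argument in place, so the result has to be assigned back for the forced boundary to take effect. A quick demonstration:

```python
import numpy as np

boundary_inds = np.array([3, 7])
np.insert(boundary_inds, 0, 0)  # return value discarded; boundary_inds unchanged
print(boundary_inds)  # [3 7]

boundary_inds = np.insert(boundary_inds, 0, 0)  # assignment keeps the new array
print(boundary_inds)  # [0 3 7]
```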
From 608c0fc9e903f1366789379f7d44cf71f6b13fdb Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 9 Sep 2024 15:07:10 -0400 Subject: [PATCH 24/24] Remove learning rate scheduler for now --- src/vak/models/frame_classification_model.py | 31 ++------------------ 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 581988157..2159c014d 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -130,8 +130,6 @@ def __init__( :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`. """ super().__init__() - self.automatic_optimization = False # so we can use learning rate scheduler - self.network = network self.loss = loss self.optimizer = optimizer @@ -174,13 +172,8 @@ def configure_optimizers(self): If None was passed in, an instance that was created with default arguments will be returned. """ - optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=4) - return { - "optimizer": optimizer, - "lr_scheduler": scheduler - } - + return self.optimizer + def forward(self, x: torch.Tensor) -> torch.Tensor: """Run a forward pass through this model's network. @@ -215,8 +208,6 @@ def training_step(self, batch: tuple, batch_idx: int): the loss function, ``self.loss``. """ frames = batch["frames"] - opt = self.optimizers() - opt.zero_grad() # we repeat this code in validation step # because I'm assuming it's faster than a call to a staticmethod that factors it out @@ -249,7 +240,6 @@ def training_step(self, batch: tuple, batch_idx: int): class_logits = self.network(frames) loss = self.loss(class_logits, batch[target_types[0]]) self.log("train_loss", loss, on_step=True) - self.manual_backward(loss) else: multi_logits, boundary_logits = self.network(frames) loss = self.loss( @@ -260,14 +250,10 @@ def training_step(self, batch: tuple, batch_idx: int): ) if isinstance(loss, torch.Tensor): self.log("train_loss", loss, on_step=True) - self.manual_backward(loss) elif isinstance(loss, dict): # this provides a mechanism to log values for all terms of a loss function with multiple terms for loss_name, loss_val in loss.items(): self.log(f"train_{loss_name}", loss_val, on_step=True) - self.manual_backward(loss["loss"]) - - opt.step() return loss @@ -511,19 +497,6 @@ def validation_step(self, batch: tuple, batch_idx: int): sync_dist=True, ) - def on_validation_end(self): - # adding this method is so we can call learning rate scheduler after computing validation metrics - scheduler = self.lr_schedulers() - if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - if "val_multi_acc" in self.trainer.callback_metrics: - # for the case where we have multiple accuracies, we step the scheduler on multi-class frame accuracy - scheduler.step(self.trainer.callback_metrics["val_multi_acc"]) - else: - scheduler.step(self.trainer.callback_metrics["val_acc"]) - lr = scheduler.get_last_lr() - logger = self.logger.experiment - logger.add_scalar('learning_rate', lr[-1], global_step=self.trainer.global_step) - def predict_step(self, batch: tuple, batch_idx: int): """Perform one prediction step.
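The series ends by reverting to automatic optimization. If the scheduler comes back later, Lightning can drive `ReduceLROnPlateau` without the manual `optimizers()`/`manual_backward()`/`step()` bookkeeping these patches had to revert, by returning a scheduler config dict from `configure_optimizers`. A sketch of that standard Lightning pattern; the monitored metric name is an assumption:

```python
def configure_optimizers(self):
    # optimizer/scheduler hyperparameters copied from Patch 01 for illustration
    optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=4
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Lightning steps the scheduler with this logged metric,
            # so there is no need to override on_validation_end
            "monitor": "val_multi_acc",
        },
    }
```

With this shape, the model keeps `automatic_optimization = True`, and the learning rate can be tracked with the built-in `LearningRateMonitor` callback instead of writing to the TensorBoard experiment by hand.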