diff --git a/doc/api/index.rst b/doc/api/index.rst
index 03048719f..4b3d2c4d2 100644
--- a/doc/api/index.rst
+++ b/doc/api/index.rst
@@ -130,14 +130,16 @@ and dataclasses that represent tables from those files.
:recursive:
config.config
+ config.dataset
config.eval
config.learncurve
+ config.load
config.model
- config.parse
config.predict
config.prep
config.spect_params
config.train
+ config.trainer
config.validators
Datasets
@@ -265,10 +267,10 @@ used by multiple other modules.
:template: module.rst
:recursive:
+ common.accelerator
common.annotation
common.constants
common.converters
- common.device
common.files
common.labels
common.learncurve
diff --git a/doc/conf.py b/doc/conf.py
index 2c3099705..f7b967bf6 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -107,8 +107,8 @@
"announcement":
""" 🚧 vak version 1.0.0 is in development! 🚧
📣 Test out the alpha release: pip install vak==1.0.0a3
. 📣
- For more info, please see
- this forum post.
+ For more info, please see
+ this forum post.
""",
"sidebar_hide_name": True,
"light_css_variables": {
@@ -217,7 +217,8 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
- "pytorch": ("https://pytorch.org/docs/stable/", None)
+ "pytorch": ("https://pytorch.org/docs/stable/", None),
+ "lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
}
# -- Options for todo extension ----------------------------------------------
diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml
index 71355ed28..41d25a1c7 100644
--- a/doc/toml/gy6or6_eval.toml
+++ b/doc/toml/gy6or6_eval.toml
@@ -43,7 +43,7 @@ batch_size = 11
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 16
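-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
+# the device is no longer set here; see the `[vak.eval.trainer]` sub-table at the end of this file
+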
# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/eval"
# dataset_path : path to dataset created by prep
@@ -72,3 +72,10 @@ window_size = 176
# Note we do not specify any options for the model, and just use the defaults
# We need to put this table here though so we know which model we are using
[vak.eval.model.TweetyNet]
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.eval.trainer]
+# setting to 'gpu' means "evaluate models on 'gpu' (not 'cpu')"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml
index c4c89ef73..4bbfa1dd2 100644
--- a/doc/toml/gy6or6_predict.toml
+++ b/doc/toml/gy6or6_predict.toml
@@ -39,7 +39,7 @@ batch_size = 1
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
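-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
+# the device is no longer set here; see the `[vak.predict.trainer]` sub-table at the end of this file
+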
# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/predict"
# annot_csv_filename
@@ -67,3 +67,10 @@ window_size = 176
# Note we do not specify any options for the network, and just use the defaults
# We need to put this table here though, to indicate which model we are using.
[vak.predict.model.TweetyNet]
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.predict.trainer]
+# setting to 'gpu' means "run models that predict on 'gpu' (not 'cpu')"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml
index 68202f796..6c802b2ec 100644
--- a/doc/toml/gy6or6_train.toml
+++ b/doc/toml/gy6or6_train.toml
@@ -53,7 +53,7 @@ patience = 4
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
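-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
+# the device is no longer set here; see the `[vak.train.trainer]` sub-table at the end of this file
+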
# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it
# dataset.params = parameters used for datasets
@@ -73,3 +73,10 @@ lr = 0.001
[vak.train.model.TweetyNet.network]
# hidden_size: the number of elements in the hidden state in the recurrent layer of the network
hidden_size = 256
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.train.trainer]
+# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/pyproject.toml b/pyproject.toml
index f0895a996..61efe6fd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"dask[dataframe] >=2.10.1",
"evfuncs >=0.3.4",
"joblib >=0.14.1",
- "pytorch-lightning >=2.0.7",
+ "lightning >=2.0.7",
"matplotlib >=3.3.3",
"numpy >=1.18.1",
"pynndescent >=0.5.10",
diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py
index 29bee65a5..a80245464 100644
--- a/src/vak/cli/eval.py
+++ b/src/vak/cli/eval.py
@@ -52,12 +52,12 @@ def eval(toml_path: str | pathlib.Path) -> None:
eval_module.eval(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
- device=cfg.eval.device,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py
index 2decc5cd8..c8f407fe0 100644
--- a/src/vak/cli/learncurve.py
+++ b/src/vak/cli/learncurve.py
@@ -55,6 +55,7 @@ def learning_curve(toml_path):
learncurve.learning_curve(
model_config=cfg.learncurve.model.asdict(),
dataset_config=cfg.learncurve.dataset.asdict(),
+ trainer_config=cfg.learncurve.trainer.asdict(),
batch_size=cfg.learncurve.batch_size,
num_epochs=cfg.learncurve.num_epochs,
num_workers=cfg.learncurve.num_workers,
@@ -65,5 +66,4 @@ def learning_curve(toml_path):
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
patience=cfg.learncurve.patience,
- device=cfg.learncurve.device,
)
diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py
index 01c0e2612..474b2b8ca 100644
--- a/src/vak/cli/predict.py
+++ b/src/vak/cli/predict.py
@@ -45,12 +45,12 @@ def predict(toml_path):
predict_module.predict(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
+ trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
- device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py
index c63096ca2..88c2fb6d9 100644
--- a/src/vak/cli/train.py
+++ b/src/vak/cli/train.py
@@ -55,6 +55,7 @@ def train(toml_path):
train_module.train(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -66,5 +67,4 @@ def train(toml_path):
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py
index 777bd9afb..11c4e330d 100644
--- a/src/vak/common/__init__.py
+++ b/src/vak/common/__init__.py
@@ -9,10 +9,10 @@
"""
from . import (
+ accelerator,
annotation,
constants,
converters,
- device,
files,
labels,
learncurve,
@@ -30,7 +30,7 @@
"annotation",
"constants",
"converters",
- "device",
+ "accelerator",
"files",
"labels",
"learncurve",
diff --git a/src/vak/common/accelerator.py b/src/vak/common/accelerator.py
new file mode 100644
index 000000000..3f3e3eee2
--- /dev/null
+++ b/src/vak/common/accelerator.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def get_default() -> str:
+    """Return the default `accelerator` for :class:`lightning.pytorch.Trainer`.
+
+    Returns
+    -------
+    accelerator : str
+        ``'gpu'`` when :func:`torch.cuda.is_available` returns ``True``,
+        and ``'cpu'`` otherwise.
+    """
+    # CUDA availability is the only signal consulted here, so the
+    # default is always one of exactly two strings.
+    cuda_found = torch.cuda.is_available()
+    return "gpu" if cuda_found else "cpu"
diff --git a/src/vak/common/device.py b/src/vak/common/device.py
deleted file mode 100644
index 5e7d506e5..000000000
--- a/src/vak/common/device.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-
-
-def get_default():
- """get default device for torch.
-
- Returns
- -------
- device : str
- 'cuda' if torch.cuda.is_available() is True,
- and returns 'cpu' otherwise.
- """
- if torch.cuda.is_available():
- return "cuda"
- else:
- return "cpu"
diff --git a/src/vak/common/trainer.py b/src/vak/common/trainer.py
index 7a6400d97..cd12d03d4 100644
--- a/src/vak/common/trainer.py
+++ b/src/vak/common/trainer.py
@@ -2,7 +2,7 @@
import pathlib
-import pytorch_lightning as lightning
+import lightning
def get_default_train_callbacks(
@@ -10,7 +10,7 @@ def get_default_train_callbacks(
ckpt_step: int,
patience: int,
):
- ckpt_callback = lightning.callbacks.ModelCheckpoint(
+ ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
dirpath=ckpt_root,
filename="checkpoint",
every_n_train_steps=ckpt_step,
@@ -20,7 +20,7 @@ def get_default_train_callbacks(
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
ckpt_callback.FILE_EXTENSION = ".pt"
- val_ckpt_callback = lightning.callbacks.ModelCheckpoint(
+ val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
monitor="val_acc",
dirpath=ckpt_root,
save_top_k=1,
@@ -31,7 +31,7 @@ def get_default_train_callbacks(
)
val_ckpt_callback.FILE_EXTENSION = ".pt"
- early_stopping = lightning.callbacks.EarlyStopping(
+ early_stopping = lightning.pytorch.callbacks.EarlyStopping(
mode="max",
monitor="val_acc",
patience=patience,
@@ -42,33 +42,47 @@ def get_default_train_callbacks(
def get_default_trainer(
+ accelerator: str,
+ devices: int | list[int],
max_steps: int,
log_save_dir: str | pathlib.Path,
val_step: int,
default_callback_kwargs: dict | None = None,
- device: str = "cuda",
-) -> lightning.Trainer:
- """Returns an instance of ``lightning.Trainer``
+) -> lightning.pytorch.Trainer:
+ """Returns an instance of :class:`lightning.pytorch.Trainer`
with a default set of callbacks.
- Used by ``vak.core`` functions."""
+
+ Used by :func:`vak.train.frame_classification`.
+ The default set of callbacks is provided by
+ :func:`get_default_train_callbacks`.
+
+ Parameters
+ ----------
+ accelerator : str
+ devices : int, list of int
+ max_steps : int
+ log_save_dir : str, pathlib.Path
+ val_step : int
+ default_callback_kwargs : dict, optional
+
+ Returns
+ -------
+ trainer : lightning.pytorch.Trainer
+
+ """
if default_callback_kwargs:
callbacks = get_default_train_callbacks(**default_callback_kwargs)
else:
callbacks = None
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)
+ logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)
- trainer = lightning.Trainer(
+ trainer = lightning.pytorch.Trainer(
+ accelerator=accelerator,
+ devices=devices,
callbacks=callbacks,
val_check_interval=val_step,
max_steps=max_steps,
- accelerator=accelerator,
logger=logger,
)
return trainer
diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py
index 056c0ef12..e0184c522 100644
--- a/src/vak/config/__init__.py
+++ b/src/vak/config/__init__.py
@@ -11,6 +11,7 @@
prep,
spect_params,
train,
+ trainer,
validators,
)
from .config import Config
@@ -22,6 +23,8 @@
from .prep import PrepConfig
from .spect_params import SpectParamsConfig
from .train import TrainConfig
+from .trainer import TrainerConfig
+
__all__ = [
"config",
@@ -34,6 +37,7 @@
"prep",
"spect_params",
"train",
+ "trainer",
"validators",
"Config",
"DatasetConfig",
@@ -44,4 +48,5 @@
"PrepConfig",
"SpectParamsConfig",
"TrainConfig",
+ "TrainerConfig",
]
diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py
index 3012da649..e524efdca 100644
--- a/src/vak/config/eval.py
+++ b/src/vak/config/eval.py
@@ -7,10 +7,10 @@
from attrs import converters, define, field, validators
from attrs.validators import instance_of
-from ..common import device
from ..common.converters import expanded_user_path
from .dataset import DatasetConfig
from .model import ModelConfig
+from .trainer import TrainerConfig
def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
@@ -77,6 +77,7 @@ def are_valid_post_tfm_kwargs(instance, attribute, value):
"dataset",
"output_dir",
"model",
+ "trainer",
)
@@ -86,29 +87,29 @@ class EvalConfig:
Attributes
----------
- dataset : vak.config.DatasetConfig
- The dataset to use: the path to it,
- and optionally a path to a file representing splits,
- and the name, if it is a built-in dataset.
- Must be an instance of :class:`vak.config.DatasetConfig`.
checkpoint_path : str
path to directory with checkpoint files saved by Torch, to reload model
output_dir : str
Path to location where .csv files with evaluation metrics should be saved.
- labelmap_path : str
- path to 'labelmap.json' file.
model : vak.config.ModelConfig
The model to use: its name,
and the parameters to configure it.
Must be an instance of :class:`vak.config.ModelConfig`
batch_size : int
number of samples per batch presented to models during training.
+ dataset : vak.config.DatasetConfig
+ The dataset to use: the path to it,
+ and optionally a path to a file representing splits,
+ and the name, if it is a built-in dataset.
+ Must be an instance of :class:`vak.config.DatasetConfig`.
+ trainer : vak.config.TrainerConfig
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Must be an instance of :class:`vak.config.TrainerConfig`.
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader. Default is 2.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
+ labelmap_path : str
+ path to 'labelmap.json' file.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
If spectrograms were normalized and this is not provided, will give
@@ -140,6 +141,9 @@ class EvalConfig:
dataset: DatasetConfig = field(
validator=instance_of(DatasetConfig),
)
+ trainer: TrainerConfig = field(
+ validator=instance_of(TrainerConfig),
+ )
# "optional" but actually required for frame classification models
# TODO: check model family in __post_init__ and raise ValueError if labelmap
@@ -161,7 +165,6 @@ class EvalConfig:
# optional, data loader
num_workers = field(validator=instance_of(int), default=2)
- device = field(validator=instance_of(str), default=device.get_default())
@classmethod
def from_config_dict(cls, config_dict: dict) -> EvalConfig:
@@ -186,4 +189,5 @@ def from_config_dict(cls, config_dict: dict) -> EvalConfig:
config_dict["model"] = ModelConfig.from_config_dict(
config_dict["model"]
)
+ config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
return cls(**config_dict)
diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py
index fdf2b883a..bb564db72 100644
--- a/src/vak/config/learncurve.py
+++ b/src/vak/config/learncurve.py
@@ -8,8 +8,9 @@
from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
from .model import ModelConfig
from .train import TrainConfig
+from .trainer import TrainerConfig
-REQUIRED_KEYS = ("dataset", "model", "root_results_dir")
+REQUIRED_KEYS = ("dataset", "model", "root_results_dir", "trainer",)
@define
@@ -22,31 +23,44 @@ class LearncurveConfig(TrainConfig):
The model to use: its name,
and the parameters to configure it.
Must be an instance of :class:`vak.config.ModelConfig`
+ num_epochs : int
+ number of training epochs. One epoch = one iteration through the entire
+ training set.
+ batch_size : int
+ number of samples per batch presented to models during training.
+ root_results_dir : str
+ directory in which results will be created.
+ The vak.cli.train function will create
+ a subdirectory in this directory each time it runs.
dataset : vak.config.DatasetConfig
The dataset to use: the path to it,
and optionally a path to a file representing splits,
and the name, if it is a built-in dataset.
Must be an instance of :class:`vak.config.DatasetConfig`.
- num_epochs : int
- number of training epochs. One epoch = one iteration through the entire
- training set.
+ trainer : vak.config.TrainerConfig
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Must be an instance of :class:`vak.config.TrainerConfig`.
+ num_workers : int
+ Number of processes to use for parallel loading of data.
+ Argument to torch.DataLoader.
+ shuffle: bool
+ if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
Normalization is done by subtracting off the mean for each frequency bin
of the training set and then dividing by the std for that frequency bin.
This same normalization is then applied to validation + test data.
+ val_step : int
+ Step on which to estimate accuracy using validation set.
+ If val_step is n, then validation is carried out every time
+        the global step / n is a whole number, i.e., when the global step modulo val_step is 0.
+ Default is None, in which case no validation is done.
ckpt_step : int
step/epoch at which to save to checkpoint file.
Default is None, in which case checkpoint is only saved at the last epoch.
patience : int
number of epochs to wait without the error dropping before stopping the
training. Default is None, in which case training continues for num_epochs
- save_only_single_checkpoint_file : bool
- if True, save only one checkpoint file instead of separate files every time
- we save. Default is True.
- use_train_subsets_from_previous_run : bool
- if True, use training subsets saved in a previous run. Default is False.
- Requires setting previous_run_path option in config.toml file.
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
@@ -61,7 +75,6 @@ class LearncurveConfig(TrainConfig):
See the docstring of the transform for more details on
these arguments and how they work.
"""
-
post_tfm_kwargs = field(
validator=validators.optional(are_valid_post_tfm_kwargs),
converter=converters.optional(convert_post_tfm_kwargs),
@@ -71,18 +84,18 @@ class LearncurveConfig(TrainConfig):
# we over-ride this method from TrainConfig mainly so the docstring is correct.
# TODO: can we do this by just over-writing `__doc__` for the method on this class?
@classmethod
- def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
+ def from_config_dict(cls, config_dict: dict) -> LearncurveConfig:
"""Return :class:`LearncurveConfig` instance from a :class:`dict`.
The :class:`dict` passed in should be the one found
by loading a valid configuration toml file with
:func:`vak.config.parse.from_toml_path`,
- and then using key ``prep``,
- i.e., ``LearncurveConfig.from_config_dict(config_dict['train'])``."""
+ and then using key ``learncurve``,
+ i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``."""
for required_key in REQUIRED_KEYS:
if required_key not in config_dict:
raise KeyError(
- "The `[vak.train]` table in a configuration file requires "
+ "The `[vak.learncurve]` table in a configuration file requires "
f"the option '{required_key}', but it was not found "
"when loading the configuration file into a Python dictionary. "
"Please check that the configuration file is formatted correctly."
@@ -93,4 +106,5 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
config_dict["dataset"] = DatasetConfig.from_config_dict(
config_dict["dataset"]
)
+ config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
return cls(**config_dict)
diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py
index 8803d9317..ee63ee093 100644
--- a/src/vak/config/predict.py
+++ b/src/vak/config/predict.py
@@ -9,15 +9,17 @@
from attr.validators import instance_of
from attrs import define, field
-from ..common import device
from ..common.converters import expanded_user_path
from .dataset import DatasetConfig
from .model import ModelConfig
+from .trainer import TrainerConfig
+
REQUIRED_KEYS = (
"checkpoint_path",
"dataset",
"model",
+ "trainer",
)
@@ -27,11 +29,6 @@ class PredictConfig:
Attributes
----------
- dataset : vak.config.DatasetConfig
- The dataset to use: the path to it,
- and optionally a path to a file representing splits,
- and the name, if it is a built-in dataset.
- Must be an instance of :class:`vak.config.DatasetConfig`.
checkpoint_path : str
path to directory with checkpoint files saved by Torch, to reload model
labelmap_path : str
@@ -40,49 +37,54 @@ class PredictConfig:
The model to use: its name,
and the parameters to configure it.
Must be an instance of :class:`vak.config.ModelConfig`
- batch_size : int
- number of samples per batch presented to models during training.
- num_workers : int
- Number of processes to use for parallel loading of data.
- Argument to torch.DataLoader. Default is 2.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
- spect_scaler_path : str
- path to a saved SpectScaler object used to normalize spectrograms.
- If spectrograms were normalized and this is not provided, will give
- incorrect results.
- annot_csv_filename : str
- name of .csv file containing predicted annotations.
- Default is None, in which case the name of the dataset .csv
- is used, with '.annot.csv' appended to it.
- output_dir : str
- path to location where .csv containing predicted annotation
- should be saved. Defaults to current working directory.
- min_segment_dur : float
- minimum duration of segment, in seconds. If specified, then
- any segment with a duration less than min_segment_dur is
- removed from lbl_tb. Default is None, in which case no
- segments are removed.
- majority_vote : bool
- if True, transform segments containing multiple labels
- into segments with a single label by taking a "majority vote",
- i.e. assign all time bins in the segment the most frequently
- occurring label in the segment. This transform can only be
- applied if the labelmap contains an 'unlabeled' label,
- because unlabeled segments makes it possible to identify
- the labeled segments. Default is False.
+ batch_size : int
+ number of samples per batch presented to models during training.
+ dataset : vak.config.DatasetConfig
+ The dataset to use: the path to it,
+ and optionally a path to a file representing splits,
+ and the name, if it is a built-in dataset.
+ Must be an instance of :class:`vak.config.DatasetConfig`.
+ trainer : vak.config.TrainerConfig
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Must be an instance of :class:`vak.config.TrainerConfig`.
+ num_workers : int
+ Number of processes to use for parallel loading of data.
+ Argument to torch.DataLoader. Default is 2.
+ spect_scaler_path : str
+ path to a saved SpectScaler object used to normalize spectrograms.
+ If spectrograms were normalized and this is not provided, will give
+ incorrect results.
+ annot_csv_filename : str
+ name of .csv file containing predicted annotations.
+ Default is None, in which case the name of the dataset .csv
+ is used, with '.annot.csv' appended to it.
+ output_dir : str
+ path to location where .csv containing predicted annotation
+ should be saved. Defaults to current working directory.
+ min_segment_dur : float
+ minimum duration of segment, in seconds. If specified, then
+ any segment with a duration less than min_segment_dur is
+ removed from lbl_tb. Default is None, in which case no
+ segments are removed.
+ majority_vote : bool
+ if True, transform segments containing multiple labels
+ into segments with a single label by taking a "majority vote",
+ i.e. assign all time bins in the segment the most frequently
+ occurring label in the segment. This transform can only be
+ applied if the labelmap contains an 'unlabeled' label,
+ because unlabeled segments makes it possible to identify
+ the labeled segments. Default is False.
save_net_outputs : bool
- if True, save 'raw' outputs of neural networks
- before they are converted to annotations. Default is False.
- Typically the output will be "logits"
- to which a softmax transform might be applied.
- For each item in the dataset--each row in the `dataset_path` .csv--
- the output will be saved in a separate file in `output_dir`,
- with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
- spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
- and the network is `TweetyNet`, then the net output file
- will be `gy6or6_032312_081416.tweetynet.output.npz`.
+ If True, save 'raw' outputs of neural networks
+ before they are converted to annotations. Default is False.
+ Typically the output will be "logits"
+ to which a softmax transform might be applied.
+ For each item in the dataset--each row in the `dataset_path` .csv--
+ the output will be saved in a separate file in `output_dir`,
+ with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
+ spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
+ and the network is `TweetyNet`, then the net output file
+ will be `gy6or6_032312_081416.tweetynet.output.npz`.
"""
# required, external files
@@ -97,6 +99,9 @@ class PredictConfig:
dataset: DatasetConfig = field(
validator=instance_of(DatasetConfig),
)
+ trainer: TrainerConfig = field(
+ validator=instance_of(TrainerConfig),
+ )
# optional, transform
spect_scaler_path = field(
@@ -106,7 +111,6 @@ class PredictConfig:
# optional, data loader
num_workers = field(validator=instance_of(int), default=2)
- device = field(validator=instance_of(str), default=device.get_default())
annot_csv_filename = field(
validator=validators.optional(instance_of(str)), default=None
@@ -133,7 +137,7 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig:
for required_key in REQUIRED_KEYS:
if required_key not in config_dict:
raise KeyError(
- "The `[vak.eval]` table in a configuration file requires "
+ "The `[vak.predict]` table in a configuration file requires "
f"the option '{required_key}', but it was not found "
"when loading the configuration file into a Python dictionary. "
"Please check that the configuration file is formatted correctly."
@@ -144,4 +148,5 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig:
config_dict["model"] = ModelConfig.from_config_dict(
config_dict["model"]
)
+ config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
return cls(**config_dict)
diff --git a/src/vak/config/train.py b/src/vak/config/train.py
index 4c997a8c1..08b8a7d49 100644
--- a/src/vak/config/train.py
+++ b/src/vak/config/train.py
@@ -3,12 +3,18 @@
from attrs import converters, define, field, validators
from attrs.validators import instance_of
-from ..common import device
from ..common.converters import bool_from_str, expanded_user_path
from .dataset import DatasetConfig
from .model import ModelConfig
+from .trainer import TrainerConfig
-REQUIRED_KEYS = ("dataset", "model", "root_results_dir")
+
+REQUIRED_KEYS = (
+ "dataset",
+ "model",
+ "root_results_dir",
+ "trainer",
+)
@define
@@ -21,11 +27,6 @@ class TrainConfig:
The model to use: its name,
and the parameters to configure it.
Must be an instance of :class:`vak.config.ModelConfig`
- dataset : vak.config.DatasetConfig
- The dataset to use: the path to it,
- and optionally a path to a file representing splits,
- and the name, if it is a built-in dataset.
- Must be an instance of :class:`vak.config.DatasetConfig`.
num_epochs : int
number of training epochs. One epoch = one iteration through the entire
training set.
@@ -35,12 +36,17 @@ class TrainConfig:
directory in which results will be created.
The vak.cli.train function will create
a subdirectory in this directory each time it runs.
+ dataset : vak.config.DatasetConfig
+ The dataset to use: the path to it,
+ and optionally a path to a file representing splits,
+ and the name, if it is a built-in dataset.
+ Must be an instance of :class:`vak.config.DatasetConfig`.
+ trainer : vak.config.TrainerConfig
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Must be an instance of :class:`vak.config.TrainerConfig`.
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
@@ -81,6 +87,9 @@ class TrainConfig:
dataset: DatasetConfig = field(
validator=instance_of(DatasetConfig),
)
+ trainer: TrainerConfig = field(
+ validator=instance_of(TrainerConfig),
+ )
results_dirname = field(
converter=converters.optional(expanded_user_path),
@@ -94,7 +103,6 @@ class TrainConfig:
)
num_workers = field(validator=instance_of(int), default=2)
- device = field(validator=instance_of(str), default=device.get_default())
shuffle = field(
converter=bool_from_str, validator=instance_of(bool), default=True
)
@@ -130,7 +138,7 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
The :class:`dict` passed in should be the one found
by loading a valid configuration toml file with
:func:`vak.config.parse.from_toml_path`,
- and then using key ``prep``,
+ and then using key ``train``,
i.e., ``TrainConfig.from_config_dict(config_dict['train'])``."""
for required_key in REQUIRED_KEYS:
if required_key not in config_dict:
@@ -146,4 +154,5 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
config_dict["dataset"] = DatasetConfig.from_config_dict(
config_dict["dataset"]
)
+ config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
return cls(**config_dict)
diff --git a/src/vak/config/trainer.py b/src/vak/config/trainer.py
new file mode 100644
index 000000000..b2e751905
--- /dev/null
+++ b/src/vak/config/trainer.py
@@ -0,0 +1,117 @@
+
+from __future__ import annotations
+
+from attrs import asdict, define, field, validators
+
+from .. import common
+
+
+def is_valid_accelerator(instance, attribute, value):
+    """Check if ``accelerator`` is valid, raising :class:`ValueError` if it is not.
+    Valid values are "cpu", "gpu", "tpu" and "ipu"; "auto" is rejected with its own message."""
+    if value == "auto":
+        raise ValueError(
+            "Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently "
+            "breaks functionality for the command-line interface of `vak`. "
+            "Please see this issue: https://github.com/vocalpy/vak/issues/691 "
+            "If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script."
+        )
+    if value not in ("cpu", "gpu", "tpu", "ipu"):
+        raise ValueError(
+            f"Invalid value for 'accelerator' key in 'trainer' table of configuration file: {value}. "
+            "Value must be one of: {\"cpu\", \"gpu\", \"tpu\", \"ipu\"}"
+        )
+
+
+def is_valid_devices(instance, attribute, value):
+    """Check if ``devices`` is an int or a list of int; raise :class:`ValueError` if not."""
+    is_int = isinstance(value, int)
+    is_list_of_int = isinstance(value, list) and all(isinstance(el, int) for el in value)
+    if not (is_int or is_list_of_int):
+        raise ValueError(
+            # must be an f-string, otherwise the literal text "{value}" is shown to the user
+            f"Invalid value for 'devices' key in 'trainer' table of configuration file: {value}"
+        )
+
+
+@define
+class TrainerConfig:
+    """Class that represents ``trainer`` sub-table
+    in a toml configuration file.
+
+    Used to configure :class:`lightning.pytorch.Trainer`.
+
+    Attributes
+    ----------
+    accelerator : str
+        Value for the `accelerator` argument to :class:`lightning.pytorch.Trainer`.
+        Default is the return value of :func:`vak.common.accelerator.get_default`.
+    devices: int, list of int
+        Number of devices (int) or exact device(s) (list of int) to use.
+        If not specified, set to 1 when ``accelerator`` is "cpu", and to ``[0]`` otherwise.
+
+    Notes
+    -----
+    Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently
+    breaks functionality for the command-line interface of `vak`.
+    Please see this issue: https://github.com/vocalpy/vak/issues/691
+    If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script.
+
+    Likewise, setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1
+    (meaning \"use a single GPU\") or a list with a single number (meaning \"use this exact GPU\") currently
+    breaks functionality for the command-line interface of `vak`.
+    Please see this issue: https://github.com/vocalpy/vak/issues/691
+    If you need to use multiple GPUs, please use `vak` directly in a script.
+    """
+    accelerator: str = field(
+        validator=is_valid_accelerator,
+        default=common.accelerator.get_default()
+    )
+    devices: int | list[int] = field(
+        validator=validators.optional(is_valid_devices),
+        # for devices, we need to look at accelerator in post-init to determine default
+        default=None,
+    )
+
+    def __attrs_post_init__(self):
+        # attrs runs the field validators *before* this hook; the `devices` validator is
+        # wrapped in `validators.optional` so that the placeholder default of None passes,
+        # and we then replace None here with a default that depends on self.accelerator
+        if self.devices is None:
+            if self.accelerator == "cpu":
+                # default to a single CPU device
+                self.devices = 1
+            elif self.accelerator in ("gpu", "tpu", "ipu"):
+                # we can only use a single device, assume there's only one
+                self.devices = [0]
+
+        if self.accelerator in ("gpu", "tpu", "ipu"):
+            if not (
+                (isinstance(self.devices, int) and self.devices == 1) or
+                (isinstance(self.devices, list) and len(self.devices) == 1 and all(isinstance(el, int) for el in self.devices))
+            ):
+                raise ValueError(
+                    "Setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1 "
+                    "(meaning \"use a single GPU\") or a list with a single number (meaning \"use this exact GPU\") currently "
+                    "breaks functionality for the command-line interface of `vak`. "
+                    "Please see this issue: https://github.com/vocalpy/vak/issues/691 "
+                    "If you need to use multiple GPUs, please use `vak` directly in a script."
+                )
+        elif self.accelerator == "cpu":
+            if isinstance(self.devices, list):
+                raise ValueError(
+                    f"Value for `devices` cannot be a list when `accelerator` is `cpu`. Value was: {self.devices}\n"
+                    "When `accelerator` is `cpu`, please set `devices` to an int > 0, e.g. 1."
+                )
+            if self.devices < 1:
+                raise ValueError(
+                    f"When value for 'accelerator' is 'cpu', value for `devices` should be an int > 0, but was: {self.devices}"
+                )
+
+    def asdict(self):
+        """Convert this :class:`TrainerConfig` instance to a :class:`dict`
+        that can be passed into functions that take a ``trainer_config``
+        argument, like :func:`vak.train` and :func:`vak.predict`.
+        """
+        return asdict(self)
diff --git a/src/vak/config/valid-version-1.0.toml b/src/vak/config/valid-version-1.1.toml
similarity index 93%
rename from src/vak/config/valid-version-1.0.toml
rename to src/vak/config/valid-version-1.1.toml
index ded6aa6ae..95aa50372 100644
--- a/src/vak/config/valid-version-1.0.toml
+++ b/src/vak/config/valid-version-1.1.toml
@@ -36,7 +36,6 @@ audio_path_key = 'audio_path'
[vak.train]
root_results_dir = './tests/test_data/results/train'
num_workers = 4
-device = 'cuda'
batch_size = 11
num_epochs = 2
normalize_spectrograms = true
@@ -56,13 +55,16 @@ params = {window_size = 2000}
[vak.train.model.TweetyNet]
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
+
[vak.eval]
checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/'
labelmap_path = '/home/user/results_181014_194418/labelmap.json'
output_dir = './tests/test_data/prep/learncurve'
batch_size = 11
num_workers = 4
-device = 'cuda'
spect_scaler_path = '/home/user/results_181014_194418/spect_scaler'
post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}
@@ -73,6 +75,10 @@ splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json'
[vak.eval.model.TweetyNet]
+[vak.eval.trainer]
+accelerator = "gpu"
+devices = [0]
+
[vak.learncurve]
root_results_dir = './tests/test_data/results/learncurve'
batch_size = 11
@@ -85,7 +91,6 @@ patience = 4
results_dir_made_by_main_script = '/some/path/to/learncurve/'
post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}
num_workers = 4
-device = 'cuda'
[vak.learncurve.dataset]
name = 'IntlDistributedSongbirdConsortiumPack'
@@ -95,6 +100,10 @@ params = {window_size = 2000}
[vak.learncurve.model.TweetyNet]
+[vak.learncurve.trainer]
+accelerator = "gpu"
+devices = [0]
+
[vak.predict]
checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/'
labelmap_path = '/home/user/results_181014_194418/labelmap.json'
@@ -102,7 +111,6 @@ annot_csv_filename = '032312_prep_191224_225910.annot.csv'
output_dir = './tests/test_data/prep/learncurve'
batch_size = 11
num_workers = 4
-device = 'cuda'
spect_scaler_path = '/home/user/results_181014_194418/spect_scaler'
min_segment_dur = 0.004
majority_vote = false
@@ -114,4 +122,8 @@ path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv'
splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json'
params = {window_size = 2000}
-[vak.predict.model.TweetyNet]
\ No newline at end of file
+[vak.predict.model.TweetyNet]
+
+[vak.predict.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py
index 628656a51..da396b8d3 100644
--- a/src/vak/config/validators.py
+++ b/src/vak/config/validators.py
@@ -59,7 +59,7 @@ def is_spect_format(instance, attribute, value):
CONFIG_DIR = pathlib.Path(__file__).parent
-VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.0.toml")
+VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.1.toml")
with VALID_TOML_PATH.open("r") as fp:
VALID_DICT = tomlkit.load(fp)["vak"]
VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys())
diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py
index fa1209f1d..665ec7d0d 100644
--- a/src/vak/eval/eval_.py
+++ b/src/vak/eval/eval_.py
@@ -16,6 +16,7 @@
def eval(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
checkpoint_path: str | pathlib.Path,
output_dir: str | pathlib.Path,
num_workers: int,
@@ -36,6 +37,9 @@ def eval(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
checkpoint_path : str, pathlib.Path
path to directory with checkpoint files saved by Torch, to reload model
output_dir : str, pathlib.Path
@@ -111,25 +115,25 @@ def eval(
eval_frame_classification_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
checkpoint_path=checkpoint_path,
labelmap_path=labelmap_path,
output_dir=output_dir,
num_workers=num_workers,
split=split,
spect_scaler_path=spect_scaler_path,
- device=device,
post_tfm_kwargs=post_tfm_kwargs,
)
elif model_family == "ParametricUMAPModel":
eval_parametric_umap_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
checkpoint_path=checkpoint_path,
output_dir=output_dir,
batch_size=batch_size,
num_workers=num_workers,
split=split,
- device=device,
)
else:
raise ValueError(f"Model family not recognized: {model_family}")
diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py
index 9757287e8..60be44b3e 100644
--- a/src/vak/eval/frame_classification.py
+++ b/src/vak/eval/frame_classification.py
@@ -10,7 +10,7 @@
import joblib
import pandas as pd
-import pytorch_lightning as lightning
+import lightning
import torch.utils.data
from .. import datasets, models, transforms
@@ -23,6 +23,7 @@
def eval_frame_classification_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
checkpoint_path: str | pathlib.Path,
labelmap_path: str | pathlib.Path,
output_dir: str | pathlib.Path,
@@ -30,7 +31,6 @@ def eval_frame_classification_model(
split: str = "test",
spect_scaler_path: str | pathlib.Path = None,
post_tfm_kwargs: dict | None = None,
- device: str | None = None,
) -> None:
"""Evaluate a trained model.
@@ -42,6 +42,9 @@ def eval_frame_classification_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
checkpoint_path : str, pathlib.Path
Path to directory with checkpoint files saved by Torch, to reload model
output_dir : str, pathlib.Path
@@ -71,18 +74,15 @@ def eval_frame_classification_model(
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
Notes
-----
- Note that unlike ``core.predict``, this function
+ Note that unlike :func:`core.predict`, this function
can modify ``labelmap`` so that metrics like edit distance
are correctly computed, by converting any string labels
in ``labelmap`` with multiple characters
to (mock) single-character labels,
- with ``vak.labels.multi_char_labels_to_single_char``.
+ with :func:`vak.labels.multi_char_labels_to_single_char`.
"""
# ---- pre-conditions ----------------------------------------------------------------------------------------------
for path, path_name in zip(
@@ -190,14 +190,12 @@ def eval_frame_classification_model(
model.load_state_dict_from_path(checkpoint_path)
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir)
- trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)
+ trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir)
+ trainer = lightning.pytorch.Trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
+ logger=trainer_logger
+ )
# TODO: check for hasattr(model, test_step) and if so run test
# below, [0] because validate returns list of dicts, length of no. of val loaders
metric_vals = trainer.validate(model, dataloaders=val_loader)[0]
diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py
index 107d8d844..cf4b13f41 100644
--- a/src/vak/eval/parametric_umap.py
+++ b/src/vak/eval/parametric_umap.py
@@ -8,7 +8,7 @@
from datetime import datetime
import pandas as pd
-import pytorch_lightning as lightning
+import lightning
import torch.utils.data
from .. import models, transforms
@@ -25,8 +25,8 @@ def eval_parametric_umap_model(
output_dir: str | pathlib.Path,
batch_size: int,
num_workers: int,
+ trainer_config: dict,
split: str = "test",
- device: str | None = None,
) -> None:
"""Evaluate a trained model.
@@ -44,15 +44,15 @@ def eval_parametric_umap_model(
Path to location where .csv files with evaluation metrics should be saved.
batch_size : int
Number of samples per batch fed into model.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader. Default is 2.
split : str
Split of dataset on which model should be evaluated.
One of {'train', 'val', 'test'}. Default is 'test'.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
"""
# ---- pre-conditions ----------------------------------------------------------------------------------------------
for path, path_name in zip(
@@ -111,14 +111,12 @@ def eval_parametric_umap_model(
model.load_state_dict_from_path(checkpoint_path)
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir)
- trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)
+ trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir)
+ trainer = lightning.pytorch.Trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
+ logger=trainer_logger
+ )
# TODO: check for hasattr(model, test_step) and if so run test
# below, [0] because validate returns list of dicts, length of no. of val loaders
metric_vals = trainer.validate(model, dataloaders=val_loader)[0]
diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py
index 8ca2e11b7..fc95d59bb 100644
--- a/src/vak/learncurve/frame_classification.py
+++ b/src/vak/learncurve/frame_classification.py
@@ -19,6 +19,7 @@
def learning_curve_for_frame_classification_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
batch_size: int,
num_epochs: int,
num_workers: int,
@@ -29,7 +30,6 @@ def learning_curve_for_frame_classification_model(
val_step: int | None = None,
ckpt_step: int | None = None,
patience: int | None = None,
- device: str | None = None,
) -> None:
"""Generate results for a learning curve,
where model performance is measured as a
@@ -49,6 +49,9 @@ def learning_curve_for_frame_classification_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
dataset_path : str
path to where dataset was saved as a csv.
batch_size : int
@@ -188,6 +191,7 @@ def learning_curve_for_frame_classification_model(
train_frame_classification_model(
model_config,
dataset_config,
+ trainer_config,
batch_size,
num_epochs,
num_workers,
@@ -197,7 +201,6 @@ def learning_curve_for_frame_classification_model(
val_step=val_step,
ckpt_step=ckpt_step,
patience=patience,
- device=device,
subset=subset,
)
@@ -238,6 +241,7 @@ def learning_curve_for_frame_classification_model(
eval_frame_classification_model(
model_config,
dataset_config,
+ trainer_config,
ckpt_path,
labelmap_path,
results_path_this_replicate,
@@ -245,7 +249,6 @@ def learning_curve_for_frame_classification_model(
"test",
spect_scaler_path,
post_tfm_kwargs,
- device,
)
# ---- make a csv for analysis -------------------------------------------------------------------------------------
diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py
index 0b6e443bf..def4b722f 100644
--- a/src/vak/learncurve/learncurve.py
+++ b/src/vak/learncurve/learncurve.py
@@ -15,6 +15,7 @@
def learning_curve(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
batch_size: int,
num_epochs: int,
num_workers: int,
@@ -25,7 +26,6 @@ def learning_curve(
val_step: int | None = None,
ckpt_step: int | None = None,
patience: int | None = None,
- device: str | None = None,
) -> None:
"""Generate results for a learning curve,
where model performance is measured as a
@@ -45,6 +45,9 @@ def learning_curve(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
batch_size : int
number of samples per batch presented to models during training.
num_epochs : int
@@ -74,10 +77,6 @@ def learning_curve(
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
- device : str
- Device on which to work with model + data.
- Default is None. If None, then a device will be selected with vak.device.get_default.
- That function defaults to 'cuda' if torch.cuda.is_available is True.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
@@ -118,6 +117,7 @@ def learning_curve(
learning_curve_for_frame_classification_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
batch_size=batch_size,
num_epochs=num_epochs,
num_workers=num_workers,
@@ -128,7 +128,6 @@ def learning_curve(
val_step=val_step,
ckpt_step=ckpt_step,
patience=patience,
- device=device,
)
else:
raise ValueError(f"Model family not recognized: {model_family}")
diff --git a/src/vak/models/base.py b/src/vak/models/base.py
index fd27b7ae3..42542099f 100644
--- a/src/vak/models/base.py
+++ b/src/vak/models/base.py
@@ -7,14 +7,14 @@
import inspect
from typing import Callable, ClassVar
-import pytorch_lightning as lightning
+import lightning
import torch
from .definition import ModelDefinition
from .definition import validate as validate_definition
-class Model(lightning.LightningModule):
+class Model(lightning.pytorch.LightningModule):
"""Base class for a model in ``vak``,
that other families of models should subclass.
@@ -286,7 +286,7 @@ def load_state_dict_from_path(self, ckpt_path):
in that chekcpoint.
This method allows loading a state dict into an instance.
- It's necessary because `lightning.LightningModule.load`` is a
+ It's necessary because ``lightning.pytorch.LightningModule.load`` is a
``classmethod``, so calling that method will trigger
``LightningModule.__init__`` instead of running
``vak.models.Model.__init__``.
@@ -297,7 +297,7 @@ def load_state_dict_from_path(self, ckpt_path):
Path to a checkpoint saved by a model in ``vak``.
This checkpoint has the same key-value pairs as
any other checkpoint saved by a
- ``lightning.LightningModule``.
+ ``lightning.pytorch.LightningModule``.
Returns
-------
diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py
index 305018e8c..0b1777e80 100644
--- a/src/vak/models/frame_classification_model.py
+++ b/src/vak/models/frame_classification_model.py
@@ -160,7 +160,7 @@ def __init__(
def configure_optimizers(self):
"""Returns the model's optimizer.
- Method required by ``lightning.LightningModule``.
+ Method required by ``lightning.pytorch.LightningModule``.
This method returns the ``optimizer`` instance passed into ``__init__``.
If None was passed in, an instance that was created
with default arguments will be returned.
@@ -185,7 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def training_step(self, batch: tuple, batch_idx: int):
"""Perform one training step.
- Method required by ``lightning.LightningModule``.
+ Method required by ``lightning.pytorch.LightningModule``.
Parameters
----------
@@ -209,7 +209,7 @@ def training_step(self, batch: tuple, batch_idx: int):
def validation_step(self, batch: tuple, batch_idx: int):
"""Perform one validation step.
- Method required by ``lightning.LightningModule``.
+ Method required by ``lightning.pytorch.LightningModule``.
Logs metrics using ``self.log``
Parameters
@@ -322,7 +322,7 @@ def validation_step(self, batch: tuple, batch_idx: int):
def predict_step(self, batch: tuple, batch_idx: int):
"""Perform one prediction step.
- Method required by ``lightning.LightningModule``.
+ Method required by ``lightning.pytorch.LightningModule``.
Parameters
----------
diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py
index 4d2c2cb94..5abaf7bbf 100644
--- a/src/vak/models/parametric_umap_model.py
+++ b/src/vak/models/parametric_umap_model.py
@@ -11,7 +11,7 @@
import pathlib
from typing import Callable, ClassVar, Type
-import pytorch_lightning as lightning
+import lightning
import torch
import torch.utils.data
@@ -126,7 +126,7 @@ def from_config(cls, config: dict):
)
-class ParametricUMAPDatamodule(lightning.LightningDataModule):
+class ParametricUMAPDatamodule(lightning.pytorch.LightningDataModule):
def __init__(
self,
dataset,
@@ -178,7 +178,7 @@ def __init__(
def fit(
self,
- trainer: lightning.Trainer,
+ trainer: lightning.pytorch.Trainer,
dataset_path: str | pathlib.Path,
transform=None,
):
diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py
index 765fb3134..e5a3cb6f7 100644
--- a/src/vak/predict/frame_classification.py
+++ b/src/vak/predict/frame_classification.py
@@ -10,13 +10,12 @@
import crowsetta
import joblib
import numpy as np
-import pytorch_lightning as lightning
+import lightning
import torch.utils.data
from tqdm import tqdm
from .. import datasets, models, transforms
from ..common import constants, files, validators
-from ..common.device import get_default as get_default_device
from ..datasets.frame_classification import FramesDataset
logger = logging.getLogger(__name__)
@@ -25,12 +24,12 @@
def predict_with_frame_classification_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
checkpoint_path,
labelmap_path,
num_workers=2,
timebins_key="t",
spect_scaler_path=None,
- device=None,
annot_csv_filename=None,
output_dir=None,
min_segment_dur=None,
@@ -48,6 +47,9 @@ def predict_with_frame_classification_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
checkpoint_path : str
path to directory with checkpoint files saved by Torch, to reload model
labelmap_path : str
@@ -59,9 +61,6 @@ def predict_with_frame_classification_model(
key for accessing spectrogram in files. Default is 's'.
timebins_key : str
key for accessing vector of time bins in files. Default is 't'.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
If spectrograms were normalized and this is not provided, will give
@@ -124,9 +123,6 @@ def predict_with_frame_classification_model(
f"value specified for output_dir is not recognized as a directory: {output_dir}"
)
- if device is None:
- device = get_default_device()
-
# ---------------- load data for prediction ------------------------------------------------------------------------
if spect_scaler_path:
logger.info(f"loading SpectScaler from path: {spect_scaler_path}")
@@ -226,14 +222,12 @@ def predict_with_frame_classification_model(
)
model.load_state_dict_from_path(checkpoint_path)
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir)
- trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)
+ trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir)
+ trainer = lightning.pytorch.Trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
+ logger=trainer_logger
+ )
logger.info(f"running predict method of {model_name}")
results = trainer.predict(model, pred_loader)
diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py
index 4e54336f4..331b571b2 100644
--- a/src/vak/predict/parametric_umap.py
+++ b/src/vak/predict/parametric_umap.py
@@ -6,12 +6,11 @@
import os
import pathlib
-import pytorch_lightning as lightning
+import lightning
import torch.utils.data
from .. import datasets, models, transforms
from ..common import validators
-from ..common.device import get_default as get_default_device
from ..datasets.parametric_umap import ParametricUMAPDataset
logger = logging.getLogger(__name__)
@@ -20,12 +19,12 @@
def predict_with_parametric_umap_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
checkpoint_path,
num_workers=2,
transform_params: dict | None = None,
dataset_params: dict | None = None,
timebins_key="t",
- device=None,
output_dir=None,
):
"""Make predictions on a dataset with a trained
@@ -39,6 +38,9 @@ def predict_with_parametric_umap_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
checkpoint_path : str
path to directory with checkpoint files saved by Torch, to reload model
num_workers : int
@@ -54,9 +56,6 @@ def predict_with_parametric_umap_model(
Optional, default is None.
timebins_key : str
key for accessing vector of time bins in files. Default is 't'.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
annot_csv_filename : str
name of .csv file containing predicted annotations.
Default is None, in which case the name of the dataset .csv
@@ -97,9 +96,6 @@ def predict_with_parametric_umap_model(
f"value specified for output_dir is not recognized as a directory: {output_dir}"
)
- if device is None:
- device = get_default_device()
-
# ---------------- load data for prediction ------------------------------------------------------------------------
model_name = model_config["name"]
# TODO: fix this when we build transforms into datasets
@@ -157,14 +153,12 @@ def predict_with_parametric_umap_model(
)
model.load_state_dict_from_path(checkpoint_path)
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir)
- trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)
+ trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir)
+ trainer = lightning.pytorch.Trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
+ logger=trainer_logger
+ )
logger.info(f"running predict method of {model_name}")
results = trainer.predict(model, pred_loader) # noqa : F841
diff --git a/src/vak/predict/predict_.py b/src/vak/predict/predict_.py
index 60373b11d..2cc60ea61 100644
--- a/src/vak/predict/predict_.py
+++ b/src/vak/predict/predict_.py
@@ -8,7 +8,7 @@
from .. import models
from ..common import validators
-from ..common.device import get_default as get_default_device
+from ..common.accelerator import get_default as get_default_device
from .frame_classification import predict_with_frame_classification_model
logger = logging.getLogger(__name__)
@@ -17,6 +17,7 @@
def predict(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
checkpoint_path: str | pathlib.Path,
labelmap_path: str | pathlib.Path,
num_workers: int = 2,
@@ -37,8 +38,12 @@ def predict(
Model configuration in a ``dict``,
as loaded from a .toml file,
and used by the model method ``from_config``.
- dataset_path : str
- Path to dataset, e.g., a csv file generated by running ``vak prep``.
+ dataset_config: dict
+ Dataset configuration in a :class:`dict`.
+ Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
checkpoint_path : str
path to directory with checkpoint files saved by Torch, to reload model
labelmap_path : str
@@ -51,9 +56,6 @@ def predict(
Argument to torch.DataLoader. Default is 2.
timebins_key : str
key for accessing vector of time bins in files. Default is 't'.
- device : str
- Device on which to work with model + data.
- Defaults to 'cuda' if torch.cuda.is_available is True.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
If spectrograms were normalized and this is not provided, will give
@@ -130,12 +132,12 @@ def predict(
predict_with_frame_classification_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
checkpoint_path=checkpoint_path,
labelmap_path=labelmap_path,
num_workers=num_workers,
timebins_key=timebins_key,
spect_scaler_path=spect_scaler_path,
- device=device,
annot_csv_filename=annot_csv_filename,
output_dir=output_dir,
min_segment_dur=min_segment_dur,
diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py
index 2af4b586d..41d521013 100644
--- a/src/vak/prep/frame_classification/make_splits.py
+++ b/src/vak/prep/frame_classification/make_splits.py
@@ -228,7 +228,7 @@ def make_splits(
network model. As returned by :func:`vak.labels.to_map`.
audio_format : str
A :class:`string` representing the format of audio files.
- One of :constant:`vak.common.constants.VALID_AUDIO_FORMATS`.
+ One of :const:`vak.common.constants.VALID_AUDIO_FORMATS`.
spect_key : str
Key for accessing spectrogram in files. Default is 's'.
timebins_key : str
diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py
index 25e07a062..e2d9074f6 100644
--- a/src/vak/train/frame_classification.py
+++ b/src/vak/train/frame_classification.py
@@ -14,7 +14,6 @@
from .. import datasets, models, transforms
from ..common import validators
-from ..common.device import get_default as get_default_device
from ..common.trainer import get_default_trainer
from ..datasets.frame_classification import FramesDataset, WindowDataset
@@ -29,6 +28,7 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float:
def train_frame_classification_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
batch_size: int,
num_epochs: int,
num_workers: int,
@@ -40,7 +40,6 @@ def train_frame_classification_model(
val_step: int | None = None,
ckpt_step: int | None = None,
patience: int | None = None,
- device: str | None = None,
subset: str | None = None,
) -> None:
"""Train a model from the frame classification family
@@ -60,6 +59,9 @@ def train_frame_classification_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
batch_size : int
number of samples per batch presented to models during training.
num_epochs : int
@@ -86,10 +88,6 @@ def train_frame_classification_model(
results_path : str, pathlib.Path
Directory where results will be saved.
If specified, this parameter overrides ``root_results_dir``.
- device : str
- Device on which to work with model + data.
- Default is None. If None, then a device will be selected with vak.split.get_default.
- That function defaults to 'cuda' if torch.cuda.is_available is True.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
@@ -279,9 +277,6 @@ def train_frame_classification_model(
else:
val_loader = None
- if device is None:
- device = get_default_device()
-
model = models.get(
model_name,
model_config,
@@ -308,11 +303,12 @@ def train_frame_classification_model(
"patience": patience,
}
trainer = get_default_trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
max_steps=max_steps,
log_save_dir=results_model_root,
val_step=val_step,
default_callback_kwargs=default_callback_kwargs,
- device=device,
)
train_time_start = datetime.datetime.now()
logger.info(f"Training start time: {train_time_start.isoformat()}")
diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py
index f9180e5c0..0fd72ca33 100644
--- a/src/vak/train/parametric_umap.py
+++ b/src/vak/train/parametric_umap.py
@@ -7,12 +7,11 @@
import pathlib
import pandas as pd
-import pytorch_lightning as lightning
+import lightning
import torch.utils.data
from .. import datasets, models, transforms
from ..common import validators
-from ..common.device import get_default as get_default_device
from ..common.paths import generate_results_dir_name_as_path
from ..datasets.parametric_umap import ParametricUMAPDataset
@@ -25,22 +24,17 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float:
def get_trainer(
+ accelerator: str,
+ devices: int | list[int],
max_epochs: int,
ckpt_root: str | pathlib.Path,
ckpt_step: int,
log_save_dir: str | pathlib.Path,
- device: str = "cuda",
-) -> lightning.Trainer:
- """Returns an instance of ``lightning.Trainer``
+) -> lightning.pytorch.Trainer:
+ """Returns an instance of ``lightning.pytorch.Trainer``
with a default set of callbacks.
Used by ``vak.core`` functions."""
- # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
- if device == "cuda":
- accelerator = "gpu"
- else:
- accelerator = "auto"
-
- ckpt_callback = lightning.callbacks.ModelCheckpoint(
+ ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
dirpath=ckpt_root,
filename="checkpoint",
every_n_train_steps=ckpt_step,
@@ -50,7 +44,7 @@ def get_trainer(
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
ckpt_callback.FILE_EXTENSION = ".pt"
- val_ckpt_callback = lightning.callbacks.ModelCheckpoint(
+ val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
monitor="val_loss",
dirpath=ckpt_root,
save_top_k=1,
@@ -66,11 +60,12 @@ def get_trainer(
val_ckpt_callback,
]
- logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)
+ logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)
- trainer = lightning.Trainer(
+ trainer = lightning.pytorch.Trainer(
max_epochs=max_epochs,
accelerator=accelerator,
+ devices=devices,
logger=logger,
callbacks=callbacks,
)
@@ -80,6 +75,7 @@ def get_trainer(
def train_parametric_umap_model(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
batch_size: int,
num_epochs: int,
num_workers: int,
@@ -89,7 +85,6 @@ def train_parametric_umap_model(
shuffle: bool = True,
val_step: int | None = None,
ckpt_step: int | None = None,
- device: str | None = None,
subset: str | None = None,
) -> None:
"""Train a model from the parametric UMAP family
@@ -109,6 +104,9 @@ def train_parametric_umap_model(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
batch_size : int
number of samples per batch presented to models during training.
num_epochs : int
@@ -137,10 +135,6 @@ def train_parametric_umap_model(
If ckpt_step is n, then a checkpoint is saved every time
the global step / n is a whole number, i.e., when ckpt_step modulo the global step is 0.
Default is None, in which case checkpoint is only saved at the last epoch.
- device : str
- Device on which to work with model + data.
- Default is None. If None, then a device will be selected with vak.split.get_default.
- That function defaults to 'cuda' if torch.cuda.is_available is True.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
split : str
@@ -248,9 +242,6 @@ def train_parametric_umap_model(
else:
val_loader = None
- if device is None:
- device = get_default_device()
-
model = models.get(
model_name,
model_config,
@@ -269,9 +260,10 @@ def train_parametric_umap_model(
ckpt_root.mkdir()
logger.info(f"training {model_name}")
trainer = get_trainer(
+ accelerator=trainer_config["accelerator"],
+ devices=trainer_config["devices"],
max_epochs=num_epochs,
log_save_dir=results_model_root,
- device=device,
ckpt_root=ckpt_root,
ckpt_step=ckpt_step,
)
diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py
index 96926967d..1a5527c88 100644
--- a/src/vak/train/train_.py
+++ b/src/vak/train/train_.py
@@ -16,6 +16,7 @@
def train(
model_config: dict,
dataset_config: dict,
+ trainer_config: dict,
batch_size: int,
num_epochs: int,
num_workers: int,
@@ -27,7 +28,6 @@ def train(
val_step: int | None = None,
ckpt_step: int | None = None,
patience: int | None = None,
- device: str | None = None,
subset: str | None = None,
):
"""Train a model and save results.
@@ -44,6 +44,9 @@ def train(
dataset_config: dict
Dataset configuration in a :class:`dict`.
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+ trainer_config: dict
+ Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
window_size : int
size of windows taken from spectrograms, in number of time bins,
shown to neural networks
@@ -147,6 +150,7 @@ def train(
train_frame_classification_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
batch_size=batch_size,
num_epochs=num_epochs,
num_workers=num_workers,
@@ -158,13 +162,13 @@ def train(
val_step=val_step,
ckpt_step=ckpt_step,
patience=patience,
- device=device,
subset=subset,
)
elif model_family == "ParametricUMAPModel":
train_parametric_umap_model(
model_config=model_config,
dataset_config=dataset_config,
+ trainer_config=trainer_config,
batch_size=batch_size,
num_epochs=num_epochs,
num_workers=num_workers,
@@ -173,7 +177,6 @@ def train(
shuffle=shuffle,
val_step=val_step,
ckpt_step=ckpt_step,
- device=device,
subset=subset,
)
else:
diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml
index b38979f48..32940a679 100644
--- a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml
@@ -17,7 +17,7 @@ transform_type = "log_spect_plus_one"
checkpoint_path = "tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP/results_230727_210112/ConvEncoderUMAP/checkpoints/checkpoint.pt"
batch_size = 64
num_workers = 16
-device = "cuda"
+
output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/ConvEncoderUMAP"
[vak.eval.model.ConvEncoderUMAP.network]
@@ -31,3 +31,7 @@ n_components = 2
[vak.eval.model.ConvEncoderUMAP.optimizer]
lr = 0.001
+
+[vak.eval.trainer]
+accelerator = "gpu"
+devices = [0]
\ No newline at end of file
diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml
index 8be5a4d3a..f188b650c 100644
--- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml
@@ -21,7 +21,7 @@ num_epochs = 1
val_step = 1
ckpt_step = 1000
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP"
[vak.train.model.ConvEncoderUMAP.network]
@@ -35,3 +35,7 @@ n_components = 2
[vak.train.model.ConvEncoderUMAP.optimizer]
lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml
index 12bfcba84..98e51a7b0 100644
--- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml
@@ -19,7 +19,7 @@ checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRep
labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json"
batch_size = 11
num_workers = 16
-device = "cuda"
+
spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect"
output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet"
@@ -43,3 +43,7 @@ hidden_size = 32
[vak.eval.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.eval.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
index 59868a28a..744cec82e 100644
--- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
@@ -27,7 +27,7 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet"
[vak.learncurve.post_tfm_kwargs]
@@ -50,3 +50,7 @@ hidden_size = 32
[vak.learncurve.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.learncurve.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
index 3d794f314..0d1cd8f13 100644
--- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
@@ -18,7 +18,7 @@ checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRep
labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json"
batch_size = 11
num_workers = 16
-device = "cuda"
+
output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet"
annot_csv_filename = "bl26lb16.041912.annot.csv"
@@ -38,3 +38,7 @@ hidden_size = 32
[vak.predict.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.predict.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
index 9b751e7f0..c5d26bf88 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
@@ -25,7 +25,7 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet"
[vak.train.dataset]
@@ -44,3 +44,7 @@ hidden_size = 32
[vak.train.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
index c7ca91a96..5af7f78d2 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
@@ -25,7 +25,7 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audio_cbin_annot_notmat/TweetyNet"
checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt"
spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect"
@@ -46,3 +46,7 @@ hidden_size = 32
[vak.train.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml
index c66e9c34d..05d897eba 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml
@@ -25,7 +25,7 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet"
checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt"
@@ -45,3 +45,7 @@ hidden_size = 32
[vak.train.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml
index a9aaaf112..4796edb60 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml
@@ -25,7 +25,7 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet"
[vak.train.dataset]
@@ -44,3 +44,7 @@ hidden_size = 32
[vak.train.model.TweetyNet.optimizer]
lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/invalid_key_config.toml b/tests/data_for_tests/configs/invalid_key_config.toml
index 0012c6d6c..76770862c 100644
--- a/tests/data_for_tests/configs/invalid_key_config.toml
+++ b/tests/data_for_tests/configs/invalid_key_config.toml
@@ -35,3 +35,7 @@ params = { window_size = 88 }
[vak.train.model.TweetyNet.optimizer]
learning_rate = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/invalid_table_config.toml b/tests/data_for_tests/configs/invalid_table_config.toml
index 24998129d..09898375f 100644
--- a/tests/data_for_tests/configs/invalid_table_config.toml
+++ b/tests/data_for_tests/configs/invalid_table_config.toml
@@ -32,3 +32,7 @@ save_only_single_checkpoint_file = true
[vak.train.model.TweetyNet.optimizer]
learning_rate = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
index a4fcd542d..3107b538c 100644
--- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
+++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
@@ -27,9 +27,13 @@ val_step = 50
ckpt_step = 200
patience = 4
num_workers = 16
-device = "cuda"
+
root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat"
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
+
[vak.learncurve]
normalize_spectrograms = true
batch_size = 11
@@ -40,7 +44,7 @@ patience = 4
num_workers = 16
train_set_durs = [ 4, 6 ]
num_replicates = 2
-device = "cuda"
+
root_results_dir = './tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat'
[vak.learncurve.model.TweetyNet.optimizer]
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index ac174ea00..18d506be7 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -6,6 +6,7 @@
from .dataframe import *
from .dataset import *
from .device import *
+from .trainer import *
from .model import *
from .path import *
from .source_files import *
diff --git a/tests/fixtures/trainer.py b/tests/fixtures/trainer.py
new file mode 100644
index 000000000..8d8becd6d
--- /dev/null
+++ b/tests/fixtures/trainer.py
@@ -0,0 +1,19 @@
+import pytest
+import torch
+
+
+TRAINER_TABLE = [
+    {"accelerator": "cpu", "devices": 1}  # always test on CPU
+]
+if torch.cuda.is_available():
+    TRAINER_TABLE.append({"accelerator": "gpu", "devices": [0]})  # add GPU case; a bare dict literal here is a no-op
+
+
+@pytest.fixture(params=TRAINER_TABLE)
+def trainer_table(request):
+ """Parametrized 'trainer' table for config file.
+
+ Any test that requests this :func:`trainer_table` fixture runs once
+ per entry in ``TRAINER_TABLE``; each run receives that entry, a dict
+ with ``accelerator`` and ``devices`` keys, via ``request.param``."""
+ return request.param
diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py
index 0ee9aba65..5a2cfdd90 100644
--- a/tests/test_cli/test_eval.py
+++ b/tests/test_cli/test_eval.py
@@ -18,7 +18,7 @@
],
)
def test_eval(
- model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+ model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
):
output_dir = tmp_path.joinpath(
f"test_eval_{audio_format}_{spect_format}_{annot_format}"
@@ -27,7 +27,7 @@ def test_eval(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
diff --git a/tests/test_cli/test_learncurve.py b/tests/test_cli/test_learncurve.py
index 7fd0a3a8b..c249b4164 100644
--- a/tests/test_cli/test_learncurve.py
+++ b/tests/test_cli/test_learncurve.py
@@ -10,7 +10,7 @@
from . import cli_asserts
-def test_learncurve(specific_config_toml_path, tmp_path, device):
+def test_learncurve(specific_config_toml_path, tmp_path, trainer_table):
root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir")
root_results_dir.mkdir()
@@ -20,7 +20,7 @@ def test_learncurve(specific_config_toml_path, tmp_path, device):
"key": "root_results_dir",
"value": str(root_results_dir),
},
- {"table": "learncurve", "key": "device", "value": device},
+ {"table": "learncurve", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py
index 30c78d3c5..d7766bcff 100644
--- a/tests/test_cli/test_predict.py
+++ b/tests/test_cli/test_predict.py
@@ -17,7 +17,7 @@
],
)
def test_predict(
- model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+ model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
):
output_dir = tmp_path.joinpath(
f"test_predict_{audio_format}_{spect_format}_{annot_format}"
@@ -26,7 +26,7 @@ def test_predict(
keys_to_change = [
{"table": "predict", "key": "output_dir", "value": str(output_dir)},
- {"table": "predict", "key": "device", "value": device},
+ {"table": "predict", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py
index a23acab3c..21c699683 100644
--- a/tests/test_cli/test_train.py
+++ b/tests/test_cli/test_train.py
@@ -19,7 +19,7 @@
],
)
def test_train(
- model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+ model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
):
root_results_dir = tmp_path.joinpath("test_train_root_results_dir")
root_results_dir.mkdir()
@@ -30,7 +30,7 @@ def test_train(
"key": "root_results_dir",
"value": str(root_results_dir),
},
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py
index de0ac681e..7917e4a6b 100644
--- a/tests/test_config/test_eval.py
+++ b/tests/test_config/test_eval.py
@@ -14,7 +14,7 @@ class TestEval:
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {
@@ -45,6 +45,7 @@ class TestEval:
def test_init(self, config_dict):
config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model'])
config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset'])
+ config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer'])
eval_config = vak.config.EvalConfig(**config_dict)
@@ -58,7 +59,7 @@ def test_init(self, config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {
@@ -108,7 +109,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {
@@ -127,7 +128,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {
@@ -158,7 +159,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {
@@ -193,7 +194,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect',
'post_tfm_kwargs': {
'majority_vote': True, 'min_segment_dur': 0.02
diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py
index 6d2d65270..fa21c152b 100644
--- a/tests/test_config/test_learncurve.py
+++ b/tests/test_config/test_learncurve.py
@@ -17,7 +17,7 @@ class TestLearncurveConfig:
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02},
'model': {
@@ -47,6 +47,7 @@ class TestLearncurveConfig:
def test_init(self, config_dict):
config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model'])
config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset'])
+ config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer'])
learncurve_config = vak.config.LearncurveConfig(**config_dict)
@@ -63,7 +64,7 @@ def test_init(self, config_dict):
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02},
'model': {
@@ -118,7 +119,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02},
'dataset': {
@@ -137,7 +138,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02},
'model': {
@@ -171,7 +172,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02},
'model': {
'TweetyNet': {
diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py
index 8d81dcf07..8dc77e7d8 100644
--- a/tests/test_config/test_predict.py
+++ b/tests/test_config/test_predict.py
@@ -15,7 +15,7 @@ class TestPredictConfig:
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet',
'annot_csv_filename': 'bl26lb16.041912.annot.csv',
'model': {
@@ -43,6 +43,7 @@ class TestPredictConfig:
def test_init(self, config_dict):
config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model'])
config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset'])
+ config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer'])
predict_config = vak.config.PredictConfig(**config_dict)
@@ -57,7 +58,7 @@ def test_init(self, config_dict):
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet',
'annot_csv_filename': 'bl26lb16.041912.annot.csv',
'model': {
@@ -104,7 +105,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet',
'annot_csv_filename': 'bl26lb16.041912.annot.csv',
'model': {
@@ -137,7 +138,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet',
'annot_csv_filename': 'bl26lb16.041912.annot.csv',
'model': {
@@ -167,7 +168,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict
'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json',
'batch_size': 11,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet',
'annot_csv_filename': 'bl26lb16.041912.annot.csv',
'dataset': {
diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py
index e5a3127da..970b1abb3 100644
--- a/tests/test_config/test_train.py
+++ b/tests/test_config/test_train.py
@@ -17,7 +17,7 @@ class TestTrainConfig:
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'model': {
'TweetyNet': {
@@ -46,6 +46,7 @@ class TestTrainConfig:
def test_init(self, config_dict):
config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model'])
config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset'])
+ config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer'])
train_config = vak.config.TrainConfig(**config_dict)
@@ -62,7 +63,7 @@ def test_init(self, config_dict):
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'model': {
'TweetyNet': {
@@ -112,7 +113,7 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict):
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'dataset': {
'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819'
@@ -129,7 +130,7 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict):
'ckpt_step': 200,
'patience': 4,
'num_workers': 16,
- 'device': 'cuda',
+ 'trainer': {'accelerator': 'gpu', 'devices': [0]},
'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet',
'model': {
'TweetyNet': {
diff --git a/tests/test_config/test_trainer.py b/tests/test_config/test_trainer.py
new file mode 100644
index 000000000..649bae96e
--- /dev/null
+++ b/tests/test_config/test_trainer.py
@@ -0,0 +1,148 @@
+import pytest
+
+import vak.config.trainer
+
+
+class TestTrainerConfig:
+    """Unit tests for :class:`vak.config.trainer.TrainerConfig`."""
+
+    @pytest.mark.parametrize(
+        'config_dict',
+        [
+            {
+                'accelerator': 'cpu',
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [0],
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [1],
+            },
+            {
+                'accelerator': 'cpu',
+                'devices': 1,
+            },
+            {
+                'devices': 1,
+            },
+        ]
+    )
+    def test_init(self, config_dict):
+        """Test that attributes match the arguments passed in,
+        and that an omitted argument gets the expected default."""
+        trainer_config = vak.config.trainer.TrainerConfig(**config_dict)
+
+        assert isinstance(trainer_config, vak.config.trainer.TrainerConfig)
+        if 'accelerator' in config_dict:
+            assert trainer_config.accelerator == config_dict['accelerator']
+        else:
+            # TODO: mock `accelerator.get_default` here, return either 'cpu' or 'gpu'
+            assert trainer_config.accelerator == vak.common.accelerator.get_default()
+        if 'devices' in config_dict:
+            assert trainer_config.devices == config_dict['devices']
+        else:
+            # compare the attribute's value, not the string literal 'accelerator',
+            # so that the assertions on the default for `devices` actually run
+            if trainer_config.accelerator == 'cpu':
+                assert trainer_config.devices == 1
+            elif trainer_config.accelerator in ('gpu', 'tpu', 'ipu'):
+                assert trainer_config.devices == [0]
+
+    @pytest.mark.parametrize(
+        'config_dict, expected_exception',
+        [
+            # 'accelerator' can't be 'auto', breaks train/eval/prep/learncurve functions
+            (
+                {
+                    'accelerator': 'auto',
+                },
+                ValueError,
+            ),
+            # raises an error because parallel across GPUs won't work right now
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': [0, 1],
+                },
+                ValueError,
+            ),
+            # 'devices' can't be -1, breaks train/eval/prep/learncurve functions
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': -1,
+                },
+                ValueError,
+            ),
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': 'auto',
+                },
+                ValueError,
+            ),
+            # when accelerator is CPU, devices must be int gt 0
+            (
+                {
+                    'accelerator': 'cpu',
+                    'devices': -1,
+                },
+                ValueError,
+            )
+        ]
+    )
+    def test_init_raises(self, config_dict, expected_exception):
+        """Test that invalid arguments raise the expected exception."""
+        with pytest.raises(expected_exception):
+            vak.config.trainer.TrainerConfig(**config_dict)
+
+    @pytest.mark.parametrize(
+        'config_dict',
+        [
+            {
+                'accelerator': 'cpu',
+            },
+            {
+                'accelerator': 'gpu',
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [0],
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [1],
+            },
+            {
+                'accelerator': 'cpu',
+                'devices': 1,
+            },
+            {
+                'devices': 1,
+            },
+        ]
+    )
+    def test_asdict(self, config_dict):
+        """Test that ``asdict`` returns a dict with the same
+        values as the arguments, or the expected defaults."""
+        trainer_config = vak.config.trainer.TrainerConfig(**config_dict)
+
+        trainer_config_asdict = trainer_config.asdict()
+
+        assert isinstance(trainer_config_asdict, dict)
+        if 'accelerator' in config_dict:
+            assert trainer_config_asdict['accelerator'] == config_dict['accelerator']
+        else:
+            # TODO: mock `accelerator.get_default` here, return either 'cpu' or 'gpu'
+            assert trainer_config_asdict['accelerator'] == vak.common.accelerator.get_default()
+        if 'devices' in config_dict:
+            assert trainer_config_asdict['devices'] == config_dict['devices']
+        else:
+            # look up the accelerator in the result rather than in `config_dict`,
+            # so this branch cannot raise KeyError when 'accelerator' was omitted
+            if trainer_config_asdict['accelerator'] == 'cpu':
+                assert trainer_config_asdict['devices'] == 1
+            elif trainer_config_asdict['accelerator'] == 'gpu':
+                assert trainer_config_asdict['devices'] == [0]
diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py
index 7a874afb9..21d543828 100644
--- a/tests/test_eval/test_eval.py
+++ b/tests/test_eval/test_eval.py
@@ -31,7 +31,7 @@ def test_eval(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": 'cpu'},
+ {"table": "eval", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}},
]
toml_path = specific_config_toml_path(
@@ -51,13 +51,13 @@ def test_eval(
vak.eval.eval(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
- device=cfg.eval.device,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py
index ee55825a5..40eb04c07 100644
--- a/tests/test_eval/test_frame_classification.py
+++ b/tests/test_eval/test_frame_classification.py
@@ -46,7 +46,7 @@ def test_eval_frame_classification_model(
annot_format,
specific_config_toml_path,
tmp_path,
- device,
+ trainer_table,
post_tfm_kwargs
):
output_dir = tmp_path.joinpath(
@@ -56,7 +56,7 @@ def test_eval_frame_classification_model(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
@@ -72,12 +72,12 @@ def test_eval_frame_classification_model(
vak.eval.frame_classification.eval_frame_classification_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
spect_scaler_path=cfg.eval.spect_scaler_path,
- device=cfg.eval.device,
post_tfm_kwargs=post_tfm_kwargs,
)
@@ -96,7 +96,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
path_option_to_change,
specific_config_toml_path,
tmp_path,
- device
+ trainer_table
):
"""Test that core.eval raises FileNotFoundError
when one of the following does not exist:
@@ -109,7 +109,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
path_option_to_change,
]
@@ -126,12 +126,12 @@ def test_eval_frame_classification_model_raises_file_not_found(
vak.eval.frame_classification.eval_frame_classification_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
spect_scaler_path=cfg.eval.spect_scaler_path,
- device=cfg.eval.device,
)
@@ -145,7 +145,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
def test_eval_frame_classification_model_raises_not_a_directory(
path_option_to_change,
specific_config_toml_path,
- device,
+ trainer_table,
tmp_path,
):
"""Test that core.eval raises NotADirectory
@@ -153,7 +153,7 @@ def test_eval_frame_classification_model_raises_not_a_directory(
"""
keys_to_change = [
path_option_to_change,
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
]
if path_option_to_change["key"] != "output_dir":
@@ -180,10 +180,10 @@ def test_eval_frame_classification_model_raises_not_a_directory(
vak.eval.frame_classification.eval_frame_classification_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
spect_scaler_path=cfg.eval.spect_scaler_path,
- device=cfg.eval.device,
)
diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py
index 4c6c7e573..b1744291d 100644
--- a/tests/test_eval/test_parametric_umap.py
+++ b/tests/test_eval/test_parametric_umap.py
@@ -25,7 +25,7 @@ def test_eval_parametric_umap_model(
annot_format,
specific_config_toml_path,
tmp_path,
- device,
+ trainer_table,
):
output_dir = tmp_path.joinpath(
f"test_eval_{audio_format}_{spect_format}_{annot_format}"
@@ -34,7 +34,7 @@ def test_eval_parametric_umap_model(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
@@ -50,11 +50,11 @@ def test_eval_parametric_umap_model(
vak.eval.parametric_umap.eval_parametric_umap_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
output_dir=cfg.eval.output_dir,
batch_size=cfg.eval.batch_size,
num_workers=cfg.eval.num_workers,
- device=cfg.eval.device,
)
assert_eval_output_matches_expected(cfg.eval.model.name, output_dir)
@@ -70,7 +70,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
path_option_to_change,
specific_config_toml_path,
tmp_path,
- device
+ trainer_table
):
"""Test that :func:`vak.eval.parametric_umap.eval_parametric_umap_model`
raises FileNotFoundError when expected"""
@@ -81,7 +81,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
keys_to_change = [
{"table": "eval", "key": "output_dir", "value": str(output_dir)},
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
path_option_to_change,
]
@@ -98,11 +98,11 @@ def test_eval_frame_classification_model_raises_file_not_found(
vak.eval.parametric_umap.eval_parametric_umap_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
output_dir=cfg.eval.output_dir,
batch_size=cfg.eval.batch_size,
num_workers=cfg.eval.num_workers,
- device=cfg.eval.device,
)
@@ -116,14 +116,14 @@ def test_eval_frame_classification_model_raises_file_not_found(
def test_eval_frame_classification_model_raises_not_a_directory(
path_option_to_change,
specific_config_toml_path,
- device,
+ trainer_table,
tmp_path,
):
"""Test that :func:`vak.eval.parametric_umap.eval_parametric_umap_model`
raises NotADirectoryError when expected"""
keys_to_change = [
path_option_to_change,
- {"table": "eval", "key": "device", "value": device},
+ {"table": "eval", "key": "trainer", "value": trainer_table},
]
if path_option_to_change["key"] != "output_dir":
@@ -150,9 +150,9 @@ def test_eval_frame_classification_model_raises_not_a_directory(
vak.eval.parametric_umap.eval_parametric_umap_model(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
+ trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
output_dir=cfg.eval.output_dir,
batch_size=cfg.eval.batch_size,
num_workers=cfg.eval.num_workers,
- device=cfg.eval.device,
)
diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py
index c88f4ebf4..ddaa99f6a 100644
--- a/tests/test_learncurve/test_frame_classification.py
+++ b/tests/test_learncurve/test_frame_classification.py
@@ -51,8 +51,8 @@ def assert_learncurve_output_matches_expected(cfg, model_name, results_path):
]
)
def test_learning_curve_for_frame_classification_model(
- model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, device):
- keys_to_change = {"table": "learncurve", "key": "device", "value": device}
+ model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, trainer_table):
+ keys_to_change = {"table": "learncurve", "key": "trainer", "value": trainer_table}
toml_path = specific_config_toml_path(
config_type="learncurve",
@@ -69,6 +69,7 @@ def test_learning_curve_for_frame_classification_model(
vak.learncurve.frame_classification.learning_curve_for_frame_classification_model(
model_config=cfg.learncurve.model.asdict(),
dataset_config=cfg.learncurve.dataset.asdict(),
+ trainer_config=cfg.learncurve.trainer.asdict(),
batch_size=cfg.learncurve.batch_size,
num_epochs=cfg.learncurve.num_epochs,
num_workers=cfg.learncurve.num_workers,
@@ -79,7 +80,6 @@ def test_learning_curve_for_frame_classification_model(
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
patience=cfg.learncurve.patience,
- device=cfg.learncurve.device,
)
assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model.name, results_path)
@@ -94,13 +94,13 @@ def test_learning_curve_for_frame_classification_model(
)
def test_learncurve_raises_not_a_directory(dir_option_to_change,
specific_config_toml_path,
- tmp_path, device):
+ tmp_path, trainer_table):
"""Test that core.learncurve.learning_curve raises NotADirectoryError
when the following directories do not exist:
results_path, previous_run_path, dataset_path
"""
keys_to_change = [
- {"table": "learncurve", "key": "device", "value": device},
+ {"table": "learncurve", "key": "trainer", "value": trainer_table},
dir_option_to_change
]
toml_path = specific_config_toml_path(
@@ -118,6 +118,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change,
vak.learncurve.frame_classification.learning_curve_for_frame_classification_model(
model_config=cfg.learncurve.model.asdict(),
dataset_config=cfg.learncurve.dataset.asdict(),
+ trainer_config=cfg.learncurve.trainer.asdict(),
batch_size=cfg.learncurve.batch_size,
num_epochs=cfg.learncurve.num_epochs,
num_workers=cfg.learncurve.num_workers,
@@ -128,5 +129,4 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change,
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
patience=cfg.learncurve.patience,
- device=cfg.learncurve.device,
)
diff --git a/tests/test_models/test_tweetynet.py b/tests/test_models/test_tweetynet.py
index 6c06c2029..ce219ea9a 100644
--- a/tests/test_models/test_tweetynet.py
+++ b/tests/test_models/test_tweetynet.py
@@ -1,7 +1,7 @@
import itertools
import pytest
-import pytorch_lightning as lightning
+import lightning
import vak
@@ -37,7 +37,7 @@ def test_model_is_decorated(self):
assert issubclass(vak.models.TweetyNet,
vak.models.base.Model)
assert issubclass(vak.models.TweetyNet,
- lightning.LightningModule)
+ lightning.pytorch.LightningModule)
@pytest.mark.parametrize(
'labelmap, input_shape',
diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py
index 5b8902f9d..f0507a016 100644
--- a/tests/test_predict/test_frame_classification.py
+++ b/tests/test_predict/test_frame_classification.py
@@ -30,7 +30,7 @@ def test_predict_with_frame_classification_model(
save_net_outputs,
specific_config_toml_path,
tmp_path,
- device,
+ trainer_table,
):
output_dir = tmp_path.joinpath(
f"test_predict_{audio_format}_{spect_format}_{annot_format}"
@@ -39,7 +39,7 @@ def test_predict_with_frame_classification_model(
keys_to_change = [
{"table": "predict", "key": "output_dir", "value": str(output_dir)},
- {"table": "predict", "key": "device", "value": device},
+ {"table": "predict", "key": "trainer", "value": trainer_table},
{"table": "predict", "key": "save_net_outputs", "value": save_net_outputs},
]
toml_path = specific_config_toml_path(
@@ -54,12 +54,12 @@ def test_predict_with_frame_classification_model(
vak.predict.frame_classification.predict_with_frame_classification_model(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
+ trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
- device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
@@ -98,7 +98,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found(
path_option_to_change,
specific_config_toml_path,
tmp_path,
- device
+ trainer_table
):
"""Test that core.eval raises FileNotFoundError
when `dataset_path` does not exist."""
@@ -109,7 +109,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found(
keys_to_change = [
{"table": "predict", "key": "output_dir", "value": str(output_dir)},
- {"table": "predict", "key": "device", "value": device},
+ {"table": "predict", "key": "trainer", "value": trainer_table},
path_option_to_change,
]
toml_path = specific_config_toml_path(
@@ -125,12 +125,12 @@ def test_predict_with_frame_classification_model_raises_file_not_found(
vak.predict.frame_classification.predict_with_frame_classification_model(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
+ trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
- device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
@@ -149,7 +149,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found(
def test_predict_with_frame_classification_model_raises_not_a_directory(
path_option_to_change,
specific_config_toml_path,
- device,
+ trainer_table,
tmp_path,
):
"""Test that core.eval raises NotADirectory
@@ -157,7 +157,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory(
"""
keys_to_change = [
path_option_to_change,
- {"table": "predict", "key": "device", "value": device},
+ {"table": "predict", "key": "trainer", "value": trainer_table},
]
if path_option_to_change["key"] != "output_dir":
@@ -184,12 +184,12 @@ def test_predict_with_frame_classification_model_raises_not_a_directory(
vak.predict.frame_classification.predict_with_frame_classification_model(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
+ trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
- device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py
index 820613ed4..f7a7f0f7a 100644
--- a/tests/test_predict/test_predict.py
+++ b/tests/test_predict/test_predict.py
@@ -28,7 +28,7 @@ def test_predict(
keys_to_change = [
{"table": "predict", "key": "output_dir", "value": str(output_dir)},
- {"table": "predict", "key": "device", "value": 'cpu'},
+ {"table": "predict", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}},
]
toml_path = specific_config_toml_path(
@@ -47,12 +47,12 @@ def test_predict(
vak.predict.predict(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
+ trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
- device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py
index ef3105144..0030e01d0 100644
--- a/tests/test_prep/test_frame_classification/test_learncurve.py
+++ b/tests/test_prep/test_frame_classification/test_learncurve.py
@@ -18,7 +18,7 @@
]
)
def test_make_index_vectors_for_each_subsets(
- model_name, audio_format, annot_format, input_type, specific_config_toml_path, device, tmp_path,
+ model_name, audio_format, annot_format, input_type, specific_config_toml_path, trainer_table, tmp_path,
):
root_results_dir = tmp_path.joinpath("tmp_root_results_dir")
root_results_dir.mkdir()
@@ -130,7 +130,7 @@ def test_make_index_vectors_for_each_subsets(
]
)
def test_make_subsets_from_dataset_df(
- model_name, audio_format, annot_format, input_type, specific_config_toml_path, device, tmp_path,
+ model_name, audio_format, annot_format, input_type, specific_config_toml_path, trainer_table, tmp_path,
):
root_results_dir = tmp_path.joinpath("tmp_root_results_dir")
root_results_dir.mkdir()
diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py
index f4e50ef46..e9d06aa98 100644
--- a/tests/test_train/test_frame_classification.py
+++ b/tests/test_train/test_frame_classification.py
@@ -42,12 +42,12 @@ def assert_train_output_matches_expected(cfg: vak.config.config.Config, model_na
],
)
def test_train_frame_classification_model(
- model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+ model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
):
results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
results_path.mkdir()
keys_to_change = [
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
{"table": "train", "key": "root_results_dir", "value": str(results_path)}
]
toml_path = specific_config_toml_path(
@@ -63,6 +63,7 @@ def test_train_frame_classification_model(
vak.train.frame_classification.train_frame_classification_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -74,7 +75,6 @@ def test_train_frame_classification_model(
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -89,12 +89,12 @@ def test_train_frame_classification_model(
],
)
def test_continue_training(
- model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+ model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
):
results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
results_path.mkdir()
keys_to_change = [
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
{"table": "train", "key": "root_results_dir", "value": str(results_path)}
]
toml_path = specific_config_toml_path(
@@ -110,6 +110,7 @@ def test_continue_training(
vak.train.frame_classification.train_frame_classification_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -121,7 +122,6 @@ def test_continue_training(
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -135,14 +135,14 @@ def test_continue_training(
]
)
def test_train_raises_file_not_found(
- path_option_to_change, specific_config_toml_path, tmp_path, device
+ path_option_to_change, specific_config_toml_path, tmp_path, trainer_table
):
"""Test that pre-conditions in `vak.train` raise FileNotFoundError
when one of the following does not exist:
checkpoint_path, dataset_path, spect_scaler_path
"""
keys_to_change = [
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
path_option_to_change
]
toml_path = specific_config_toml_path(
@@ -161,6 +161,7 @@ def test_train_raises_file_not_found(
vak.train.frame_classification.train_frame_classification_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -172,7 +173,6 @@ def test_train_raises_file_not_found(
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
@@ -184,14 +184,14 @@ def test_train_raises_file_not_found(
]
)
def test_train_raises_not_a_directory(
- path_option_to_change, specific_config_toml_path, device, tmp_path
+ path_option_to_change, specific_config_toml_path, trainer_table, tmp_path
):
"""Test that core.train raises NotADirectory
when directory does not exist
"""
keys_to_change = [
path_option_to_change,
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
@@ -211,6 +211,7 @@ def test_train_raises_not_a_directory(
vak.train.frame_classification.train_frame_classification_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -222,5 +223,4 @@ def test_train_raises_not_a_directory(
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py
index 509078c21..247009a8d 100644
--- a/tests/test_train/test_parametric_umap.py
+++ b/tests/test_train/test_parametric_umap.py
@@ -35,12 +35,12 @@ def assert_train_output_matches_expected(cfg: vak.config.config.Config, model_na
)
def test_train_parametric_umap_model(
model_name, audio_format, spect_format, annot_format,
- specific_config_toml_path, tmp_path, device
+ specific_config_toml_path, tmp_path, trainer_table
):
results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
results_path.mkdir()
keys_to_change = [
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
{"table": "train", "key": "root_results_dir", "value": str(results_path)}
]
toml_path = specific_config_toml_path(
@@ -56,6 +56,7 @@ def test_train_parametric_umap_model(
vak.train.parametric_umap.train_parametric_umap_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -64,7 +65,6 @@ def test_train_parametric_umap_model(
shuffle=cfg.train.shuffle,
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
- device=cfg.train.device,
)
assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -77,14 +77,14 @@ def test_train_parametric_umap_model(
]
)
def test_train_parametric_umap_model_raises_file_not_found(
- path_option_to_change, specific_config_toml_path, tmp_path, device
+ path_option_to_change, specific_config_toml_path, tmp_path, trainer_table
):
"""Test that pre-conditions in :func:`vak.train.parametric_umap.train_parametric_umap_model`
raise FileNotFoundError when one of the following does not exist:
checkpoint_path, dataset_path
"""
keys_to_change = [
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
path_option_to_change
]
toml_path = specific_config_toml_path(
@@ -103,6 +103,7 @@ def test_train_parametric_umap_model_raises_file_not_found(
vak.train.parametric_umap.train_parametric_umap_model(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -111,7 +112,6 @@ def test_train_parametric_umap_model_raises_file_not_found(
shuffle=cfg.train.shuffle,
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
- device=cfg.train.device,
)
@@ -123,14 +123,14 @@ def test_train_parametric_umap_model_raises_file_not_found(
]
)
def test_train_parametric_umap_model_raises_not_a_directory(
- path_option_to_change, specific_config_toml_path, device, tmp_path
+ path_option_to_change, specific_config_toml_path, trainer_table, tmp_path
):
"""Test that core.train raises NotADirectory
when directory does not exist
"""
keys_to_change = [
path_option_to_change,
- {"table": "train", "key": "device", "value": device},
+ {"table": "train", "key": "trainer", "value": trainer_table},
]
toml_path = specific_config_toml_path(
@@ -151,6 +151,7 @@ def test_train_parametric_umap_model_raises_not_a_directory(
vak.train.parametric_umap.train_parametric_umap_model(
model_config=model_config,
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -159,5 +160,4 @@ def test_train_parametric_umap_model_raises_not_a_directory(
shuffle=cfg.train.shuffle,
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
- device=cfg.train.device,
)
diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py
index b9038007e..11886db68 100644
--- a/tests/test_train/test_train.py
+++ b/tests/test_train/test_train.py
@@ -35,7 +35,7 @@ def test_train(
"key": "root_results_dir",
"value": str(root_results_dir),
},
- {"table": "train", "key": "device", "value": 'cpu'},
+ {"table": "train", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}},
]
toml_path = specific_config_toml_path(
@@ -55,6 +55,7 @@ def test_train(
vak.train.train(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
+ trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -66,6 +67,5 @@ def test_train(
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
- device=cfg.train.device,
)
assert mock_train_function.called