diff --git a/doc/api/index.rst b/doc/api/index.rst
index 03048719f..4b3d2c4d2 100644
--- a/doc/api/index.rst
+++ b/doc/api/index.rst
@@ -130,14 +130,16 @@ and dataclasses that represent tables from those files.
    :recursive:

    config.config
+   config.dataset
    config.eval
    config.learncurve
+   config.load
    config.model
-   config.parse
    config.predict
    config.prep
    config.spect_params
    config.train
+   config.trainer
    config.validators

 Datasets
@@ -265,10 +267,10 @@ used by multiple other modules.
    :template: module.rst
    :recursive:

+   common.accelerator
    common.annotation
    common.constants
    common.converters
-   common.device
    common.files
    common.labels
    common.learncurve
diff --git a/doc/conf.py b/doc/conf.py
index 2c3099705..f7b967bf6 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -107,8 +107,8 @@
     "announcement": """
     🚧 vak version 1.0.0 is in development! 🚧
     📣 Test out the alpha release: pip install vak==1.0.0a3. 📣
-    For more info, please see
-    this forum post.
+    For more info, please see
+    this forum post.
     """,
     "sidebar_hide_name": True,
     "light_css_variables": {
@@ -217,7 +217,8 @@
     "python": ("https://docs.python.org/3/", None),
     "numpy": ("https://numpy.org/doc/stable/", None),
     "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
-    "pytorch": ("https://pytorch.org/docs/stable/", None)
+    "pytorch": ("https://pytorch.org/docs/stable/", None),
+    "lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
 }

 # -- Options for todo extension ----------------------------------------------
diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml
index 71355ed28..41d25a1c7 100644
--- a/doc/toml/gy6or6_eval.toml
+++ b/doc/toml/gy6or6_eval.toml
@@ -43,7 +43,5 @@ batch_size = 11
 # num_workers: number of workers to use when loading data with multiprocessing
 num_workers = 16
-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
 # output_dir: directory where output should be saved, as a sub-directory within `output_dir`
 output_dir = "/PATH/TO/FOLDER/results/eval"
 # dataset_path : path to dataset created by prep
@@ -72,3 +72,10 @@ window_size = 176
 # Note we do not specify any options for the model, and just use the defaults
 # We need to put this table here though so we know which model we are using
 [vak.eval.model.TweetyNet]
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.eval.trainer]
+# setting to 'gpu' means "run models on the GPU (not the CPU)"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml
index c4c89ef73..4bbfa1dd2 100644
--- a/doc/toml/gy6or6_predict.toml
+++ b/doc/toml/gy6or6_predict.toml
@@ -39,7 +39,5 @@ batch_size = 1
 # num_workers: number of workers to use when loading data with multiprocessing
 num_workers = 4
-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
 # output_dir: directory where output should be saved, as a sub-directory within `output_dir`
 output_dir = "/PATH/TO/FOLDER/results/predict"
 # annot_csv_filename
@@ -67,3 +67,10 @@ window_size = 176
 # Note we do not specify any options for the network, and just use the defaults
 # We need to put this table here though, to indicate which model we are using.
 [vak.predict.model.TweetyNet]
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.predict.trainer]
+# setting to 'gpu' means "run models on the GPU (not the CPU)"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml
index 68202f796..6c802b2ec 100644
--- a/doc/toml/gy6or6_train.toml
+++ b/doc/toml/gy6or6_train.toml
@@ -53,7 +53,5 @@ patience = 4
 # num_workers: number of workers to use when loading data with multiprocessing
 num_workers = 4
-# device: name of device to run model on, one of "cuda", "cpu"
-device = "cuda"
 # dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it

 # dataset.params = parameters used for datasets
@@ -73,3 +73,10 @@ lr = 0.001
 [vak.train.model.TweetyNet.network]
 # hidden_size: the number of elements in the hidden state in the recurrent layer of the network
 hidden_size = 256
+
+# this sub-table configures the `lightning.pytorch.Trainer`
+[vak.train.trainer]
+# setting to 'gpu' means "train models on the GPU (not the CPU)"
+accelerator = "gpu"
+# use the first GPU (numbering starts from 0)
+devices = [0]
diff --git a/pyproject.toml b/pyproject.toml
index f0895a996..61efe6fd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "dask[dataframe] >=2.10.1",
     "evfuncs >=0.3.4",
     "joblib >=0.14.1",
-    "pytorch-lightning >=2.0.7",
+    "lightning >=2.0.7",
    "matplotlib >=3.3.3",
     "numpy >=1.18.1",
     "pynndescent >=0.5.10",
diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py
index 29bee65a5..a80245464 100644
--- a/src/vak/cli/eval.py
+++ b/src/vak/cli/eval.py
@@ -52,12 +52,12 @@ def eval(toml_path: str | pathlib.Path) -> None:
     eval_module.eval(
         model_config=cfg.eval.model.asdict(),
         dataset_config=cfg.eval.dataset.asdict(),
+        trainer_config=cfg.eval.trainer.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         batch_size=cfg.eval.batch_size,
         spect_scaler_path=cfg.eval.spect_scaler_path,
-        device=cfg.eval.device,
         post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
     )
diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py
index 2decc5cd8..c8f407fe0 100644
--- a/src/vak/cli/learncurve.py
+++ b/src/vak/cli/learncurve.py
@@ -55,6 +55,7 @@ def learning_curve(toml_path):
     learncurve.learning_curve(
         model_config=cfg.learncurve.model.asdict(),
         dataset_config=cfg.learncurve.dataset.asdict(),
+        trainer_config=cfg.learncurve.trainer.asdict(),
         batch_size=cfg.learncurve.batch_size,
         num_epochs=cfg.learncurve.num_epochs,
         num_workers=cfg.learncurve.num_workers,
@@ -65,5 +66,4 @@ def learning_curve(toml_path):
         val_step=cfg.learncurve.val_step,
         ckpt_step=cfg.learncurve.ckpt_step,
         patience=cfg.learncurve.patience,
-        device=cfg.learncurve.device,
     )
diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py
index 01c0e2612..474b2b8ca 100644
--- a/src/vak/cli/predict.py
+++ b/src/vak/cli/predict.py
@@ -45,12 +45,12 @@ def predict(toml_path):
     predict_module.predict(
         model_config=cfg.predict.model.asdict(),
         dataset_config=cfg.predict.dataset.asdict(),
+        trainer_config=cfg.predict.trainer.asdict(),
         checkpoint_path=cfg.predict.checkpoint_path,
         labelmap_path=cfg.predict.labelmap_path,
         num_workers=cfg.predict.num_workers,
         timebins_key=cfg.prep.spect_params.timebins_key,
         spect_scaler_path=cfg.predict.spect_scaler_path,
-        device=cfg.predict.device,
         annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir, min_segment_dur=cfg.predict.min_segment_dur, diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index c63096ca2..88c2fb6d9 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -55,6 +55,7 @@ def train(toml_path): train_module.train( model_config=cfg.train.model.asdict(), dataset_config=cfg.train.dataset.asdict(), + trainer_config=cfg.train.trainer.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -66,5 +67,4 @@ def train(toml_path): val_step=cfg.train.val_step, ckpt_step=cfg.train.ckpt_step, patience=cfg.train.patience, - device=cfg.train.device, ) diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index 777bd9afb..11c4e330d 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -9,10 +9,10 @@ """ from . import ( + accelerator, annotation, constants, converters, - device, files, labels, learncurve, @@ -30,7 +30,7 @@ "annotation", "constants", "converters", - "device", + "accelerator", "files", "labels", "learncurve", diff --git a/src/vak/common/accelerator.py b/src/vak/common/accelerator.py new file mode 100644 index 000000000..3f3e3eee2 --- /dev/null +++ b/src/vak/common/accelerator.py @@ -0,0 +1,16 @@ +import torch + + +def get_default() -> str: + """Get default `accelerator` for :class:`lightning.pytorch.Trainer`. + + Returns + ------- + accelerator : str + Will be ``'gpu'`` if :func:`torch.cuda.is_available` + is ``True``, and ``'cpu'`` if not. + """ + if torch.cuda.is_available(): + return "gpu" + else: + return "cpu" diff --git a/src/vak/common/device.py b/src/vak/common/device.py deleted file mode 100644 index 5e7d506e5..000000000 --- a/src/vak/common/device.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - - -def get_default(): - """get default device for torch. - - Returns - ------- - device : str - 'cuda' if torch.cuda.is_available() is True, - and returns 'cpu' otherwise. 
- """ - if torch.cuda.is_available(): - return "cuda" - else: - return "cpu" diff --git a/src/vak/common/trainer.py b/src/vak/common/trainer.py index 7a6400d97..cd12d03d4 100644 --- a/src/vak/common/trainer.py +++ b/src/vak/common/trainer.py @@ -2,7 +2,7 @@ import pathlib -import pytorch_lightning as lightning +import lightning def get_default_train_callbacks( @@ -10,7 +10,7 @@ def get_default_train_callbacks( ckpt_step: int, patience: int, ): - ckpt_callback = lightning.callbacks.ModelCheckpoint( + ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( dirpath=ckpt_root, filename="checkpoint", every_n_train_steps=ckpt_step, @@ -20,7 +20,7 @@ def get_default_train_callbacks( ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" ckpt_callback.FILE_EXTENSION = ".pt" - val_ckpt_callback = lightning.callbacks.ModelCheckpoint( + val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( monitor="val_acc", dirpath=ckpt_root, save_top_k=1, @@ -31,7 +31,7 @@ def get_default_train_callbacks( ) val_ckpt_callback.FILE_EXTENSION = ".pt" - early_stopping = lightning.callbacks.EarlyStopping( + early_stopping = lightning.pytorch.callbacks.EarlyStopping( mode="max", monitor="val_acc", patience=patience, @@ -42,33 +42,47 @@ def get_default_train_callbacks( def get_default_trainer( + accelerator: str, + devices: int | list[int], max_steps: int, log_save_dir: str | pathlib.Path, val_step: int, default_callback_kwargs: dict | None = None, - device: str = "cuda", -) -> lightning.Trainer: - """Returns an instance of ``lightning.Trainer`` +) -> lightning.pytorch.Trainer: + """Returns an instance of :class:`lightning.pytorch.Trainer` with a default set of callbacks. - Used by ``vak.core`` functions.""" + + Used by :func:`vak.train.frame_classification`. + The default set of callbacks is provided by + :func:`get_default_train_callbacks`. 
+ + Parameters + ---------- + accelerator : str + devices : int, list of int + max_steps : int + log_save_dir : str, pathlib.Path + val_step : int + default_callback_kwargs : dict, optional + + Returns + ------- + trainer : lightning.pytorch.Trainer + + """ if default_callback_kwargs: callbacks = get_default_train_callbacks(**default_callback_kwargs) else: callbacks = None - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir) + logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) - trainer = lightning.Trainer( + trainer = lightning.pytorch.Trainer( + accelerator=accelerator, + devices=devices, callbacks=callbacks, val_check_interval=val_step, max_steps=max_steps, - accelerator=accelerator, logger=logger, ) return trainer diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 056c0ef12..e0184c522 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -11,6 +11,7 @@ prep, spect_params, train, + trainer, validators, ) from .config import Config @@ -22,6 +23,8 @@ from .prep import PrepConfig from .spect_params import SpectParamsConfig from .train import TrainConfig +from .trainer import TrainerConfig + __all__ = [ "config", @@ -34,6 +37,7 @@ "prep", "spect_params", "train", + "trainer", "validators", "Config", "DatasetConfig", @@ -44,4 +48,5 @@ "PrepConfig", "SpectParamsConfig", "TrainConfig", + "TrainerConfig", ] diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index 3012da649..e524efdca 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -7,10 +7,10 @@ from attrs import converters, define, field, validators from attrs.validators import instance_of -from ..common import device from ..common.converters import expanded_user_path from .dataset import DatasetConfig from .model import ModelConfig +from .trainer import TrainerConfig def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: @@ -77,6 +77,7 @@ def are_valid_post_tfm_kwargs(instance, attribute, value): "dataset", "output_dir", "model", + "trainer", ) @@ -86,29 +87,29 @@ class EvalConfig: Attributes ---------- - dataset : vak.config.DatasetConfig - The dataset to use: the path to it, - and optionally a path to a file representing splits, - and the name, if it is a built-in dataset. - Must be an instance of :class:`vak.config.DatasetConfig`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model output_dir : str Path to location where .csv files with evaluation metrics should be saved. - labelmap_path : str - path to 'labelmap.json' file. model : vak.config.ModelConfig The model to use: its name, and the parameters to configure it. Must be an instance of :class:`vak.config.ModelConfig` batch_size : int number of samples per batch presented to models during training. + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. + trainer : vak.config.TrainerConfig + Configuration for :class:`lightning.pytorch.Trainer`. + Must be an instance of :class:`vak.config.TrainerConfig`. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. - device : str - Device on which to work with model + data. 
-        Defaults to 'cuda' if torch.cuda.is_available is True.
+    labelmap_path : str
+        path to 'labelmap.json' file.
     spect_scaler_path : str
         path to a saved SpectScaler object used to normalize spectrograms.
         If spectrograms were normalized and this is not provided, will give
@@ -140,6 +141,9 @@ class EvalConfig:
     dataset: DatasetConfig = field(
         validator=instance_of(DatasetConfig),
     )
+    trainer: TrainerConfig = field(
+        validator=instance_of(TrainerConfig),
+    )

     # "optional" but actually required for frame classification models
     # TODO: check model family in __post_init__ and raise ValueError if labelmap
@@ -161,7 +165,6 @@ class EvalConfig:

     # optional, data loader
     num_workers = field(validator=instance_of(int), default=2)
-    device = field(validator=instance_of(str), default=device.get_default())

     @classmethod
     def from_config_dict(cls, config_dict: dict) -> EvalConfig:
@@ -186,4 +189,5 @@ def from_config_dict(cls, config_dict: dict) -> EvalConfig:
         config_dict["model"] = ModelConfig.from_config_dict(
             config_dict["model"]
         )
+        config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
         return cls(**config_dict)
diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py
index fdf2b883a..bb564db72 100644
--- a/src/vak/config/learncurve.py
+++ b/src/vak/config/learncurve.py
@@ -8,8 +8,9 @@
 from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
 from .model import ModelConfig
 from .train import TrainConfig
+from .trainer import TrainerConfig

-REQUIRED_KEYS = ("dataset", "model", "root_results_dir")
+REQUIRED_KEYS = ("dataset", "model", "root_results_dir", "trainer",)


 @define
@@ -22,31 +23,44 @@ class LearncurveConfig(TrainConfig):
         The model to use: its name,
         and the parameters to configure it.
         Must be an instance of :class:`vak.config.ModelConfig`
+    num_epochs : int
+        number of training epochs. One epoch = one iteration through the entire
+        training set.
+    batch_size : int
+        number of samples per batch presented to models during training.
+    root_results_dir : str
+        directory in which results will be created.
+        The vak.cli.train function will create
+        a subdirectory in this directory each time it runs.
     dataset : vak.config.DatasetConfig
         The dataset to use: the path to it,
         and optionally a path to a file representing splits,
         and the name, if it is a built-in dataset.
         Must be an instance of :class:`vak.config.DatasetConfig`.
-    num_epochs : int
-        number of training epochs. One epoch = one iteration through the entire
-        training set.
+    trainer : vak.config.TrainerConfig
+        Configuration for :class:`lightning.pytorch.Trainer`.
+        Must be an instance of :class:`vak.config.TrainerConfig`.
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader.
+    shuffle: bool
+        if True, shuffle training data before each epoch. Default is True.
     normalize_spectrograms : bool
         if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
         Normalization is done by subtracting off the mean for each frequency bin
         of the training set and then dividing by the std for that frequency bin.
         This same normalization is then applied to validation + test data.
+    val_step : int
+        Step on which to estimate accuracy using validation set.
+        If val_step is n, then validation is carried out every time
+        the global step / n is a whole number, i.e., when the global step modulo val_step is 0.
+        Default is None, in which case no validation is done.
     ckpt_step : int
         step/epoch at which to save to checkpoint file.
Default is None, in which case checkpoint is only saved at the last epoch. patience : int number of epochs to wait without the error dropping before stopping the training. Default is None, in which case training continues for num_epochs - save_only_single_checkpoint_file : bool - if True, save only one checkpoint file instead of separate files every time - we save. Default is True. - use_train_subsets_from_previous_run : bool - if True, use training subsets saved in a previous run. Default is False. - Requires setting previous_run_path option in config.toml file. post_tfm_kwargs : dict Keyword arguments to post-processing transform. If None, then no additional clean-up is applied @@ -61,7 +75,6 @@ class LearncurveConfig(TrainConfig): See the docstring of the transform for more details on these arguments and how they work. """ - post_tfm_kwargs = field( validator=validators.optional(are_valid_post_tfm_kwargs), converter=converters.optional(convert_post_tfm_kwargs), @@ -71,18 +84,18 @@ class LearncurveConfig(TrainConfig): # we over-ride this method from TrainConfig mainly so the docstring is correct. # TODO: can we do this by just over-writing `__doc__` for the method on this class? @classmethod - def from_config_dict(cls, config_dict: dict) -> "TrainConfig": + def from_config_dict(cls, config_dict: dict) -> LearncurveConfig: """Return :class:`LearncurveConfig` instance from a :class:`dict`. The :class:`dict` passed in should be the one found by loading a valid configuration toml file with :func:`vak.config.parse.from_toml_path`, - and then using key ``prep``, - i.e., ``LearncurveConfig.from_config_dict(config_dict['train'])``.""" + and then using key ``learncurve``, + i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``.""" for required_key in REQUIRED_KEYS: if required_key not in config_dict: raise KeyError( - "The `[vak.train]` table in a configuration file requires " + "The `[vak.learncurve]` table in a configuration file requires " f"the option '{required_key}', but it was not found " "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." @@ -93,4 +106,5 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig": config_dict["dataset"] = DatasetConfig.from_config_dict( config_dict["dataset"] ) + config_dict["trainer"] = TrainerConfig(**config_dict["trainer"]) return cls(**config_dict) diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index 8803d9317..ee63ee093 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -9,15 +9,17 @@ from attr.validators import instance_of from attrs import define, field -from ..common import device from ..common.converters import expanded_user_path from .dataset import DatasetConfig from .model import ModelConfig +from .trainer import TrainerConfig + REQUIRED_KEYS = ( "checkpoint_path", "dataset", "model", + "trainer", ) @@ -27,11 +29,6 @@ class PredictConfig: Attributes ---------- - dataset : vak.config.DatasetConfig - The dataset to use: the path to it, - and optionally a path to a file representing splits, - and the name, if it is a built-in dataset. - Must be an instance of :class:`vak.config.DatasetConfig`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str @@ -40,49 +37,54 @@ class PredictConfig: The model to use: its name, and the parameters to configure it. 
Must be an instance of :class:`vak.config.ModelConfig` - batch_size : int - number of samples per batch presented to models during training. - num_workers : int - Number of processes to use for parallel loading of data. - Argument to torch.DataLoader. Default is 2. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. - spect_scaler_path : str - path to a saved SpectScaler object used to normalize spectrograms. - If spectrograms were normalized and this is not provided, will give - incorrect results. - annot_csv_filename : str - name of .csv file containing predicted annotations. - Default is None, in which case the name of the dataset .csv - is used, with '.annot.csv' appended to it. - output_dir : str - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. - min_segment_dur : float - minimum duration of segment, in seconds. If specified, then - any segment with a duration less than min_segment_dur is - removed from lbl_tb. Default is None, in which case no - segments are removed. - majority_vote : bool - if True, transform segments containing multiple labels - into segments with a single label by taking a "majority vote", - i.e. assign all time bins in the segment the most frequently - occurring label in the segment. This transform can only be - applied if the labelmap contains an 'unlabeled' label, - because unlabeled segments makes it possible to identify - the labeled segments. Default is False. + batch_size : int + number of samples per batch presented to models during training. + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. + trainer : vak.config.TrainerConfig + Configuration for :class:`lightning.pytorch.Trainer`. + Must be an instance of :class:`vak.config.TrainerConfig`. + num_workers : int + Number of processes to use for parallel loading of data. + Argument to torch.DataLoader. Default is 2. + spect_scaler_path : str + path to a saved SpectScaler object used to normalize spectrograms. + If spectrograms were normalized and this is not provided, will give + incorrect results. + annot_csv_filename : str + name of .csv file containing predicted annotations. + Default is None, in which case the name of the dataset .csv + is used, with '.annot.csv' appended to it. + output_dir : str + path to location where .csv containing predicted annotation + should be saved. Defaults to current working directory. + min_segment_dur : float + minimum duration of segment, in seconds. If specified, then + any segment with a duration less than min_segment_dur is + removed from lbl_tb. Default is None, in which case no + segments are removed. + majority_vote : bool + if True, transform segments containing multiple labels + into segments with a single label by taking a "majority vote", + i.e. assign all time bins in the segment the most frequently + occurring label in the segment. This transform can only be + applied if the labelmap contains an 'unlabeled' label, + because unlabeled segments makes it possible to identify + the labeled segments. Default is False. save_net_outputs : bool - if True, save 'raw' outputs of neural networks - before they are converted to annotations. Default is False. - Typically the output will be "logits" - to which a softmax transform might be applied. 
- For each item in the dataset--each row in the `dataset_path` .csv-- - the output will be saved in a separate file in `output_dir`, - with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a - spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, - and the network is `TweetyNet`, then the net output file - will be `gy6or6_032312_081416.tweetynet.output.npz`. + If True, save 'raw' outputs of neural networks + before they are converted to annotations. Default is False. + Typically the output will be "logits" + to which a softmax transform might be applied. + For each item in the dataset--each row in the `dataset_path` .csv-- + the output will be saved in a separate file in `output_dir`, + with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a + spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, + and the network is `TweetyNet`, then the net output file + will be `gy6or6_032312_081416.tweetynet.output.npz`. """ # required, external files @@ -97,6 +99,9 @@ class PredictConfig: dataset: DatasetConfig = field( validator=instance_of(DatasetConfig), ) + trainer: TrainerConfig = field( + validator=instance_of(TrainerConfig), + ) # optional, transform spect_scaler_path = field( @@ -106,7 +111,6 @@ class PredictConfig: # optional, data loader num_workers = field(validator=instance_of(int), default=2) - device = field(validator=instance_of(str), default=device.get_default()) annot_csv_filename = field( validator=validators.optional(instance_of(str)), default=None @@ -133,7 +137,7 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig: for required_key in REQUIRED_KEYS: if required_key not in config_dict: raise KeyError( - "The `[vak.eval]` table in a configuration file requires " + "The `[vak.predict]` table in a configuration file requires " f"the option '{required_key}', but it was not found " "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." @@ -144,4 +148,5 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig: config_dict["model"] = ModelConfig.from_config_dict( config_dict["model"] ) + config_dict["trainer"] = TrainerConfig(**config_dict["trainer"]) return cls(**config_dict) diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 4c997a8c1..08b8a7d49 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -3,12 +3,18 @@ from attrs import converters, define, field, validators from attrs.validators import instance_of -from ..common import device from ..common.converters import bool_from_str, expanded_user_path from .dataset import DatasetConfig from .model import ModelConfig +from .trainer import TrainerConfig -REQUIRED_KEYS = ("dataset", "model", "root_results_dir") + +REQUIRED_KEYS = ( + "dataset", + "model", + "root_results_dir", + "trainer", +) @define @@ -21,11 +27,6 @@ class TrainConfig: The model to use: its name, and the parameters to configure it. Must be an instance of :class:`vak.config.ModelConfig` - dataset : vak.config.DatasetConfig - The dataset to use: the path to it, - and optionally a path to a file representing splits, - and the name, if it is a built-in dataset. - Must be an instance of :class:`vak.config.DatasetConfig`. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. @@ -35,12 +36,17 @@ class TrainConfig: directory in which results will be created. The vak.cli.train function will create a subdirectory in this directory each time it runs. 
+    dataset : vak.config.DatasetConfig
+        The dataset to use: the path to it,
+        and optionally a path to a file representing splits,
+        and the name, if it is a built-in dataset.
+        Must be an instance of :class:`vak.config.DatasetConfig`.
+    trainer : vak.config.TrainerConfig
+        Configuration for :class:`lightning.pytorch.Trainer`.
+        Must be an instance of :class:`vak.config.TrainerConfig`.
     num_workers : int
         Number of processes to use for parallel loading of data.
         Argument to torch.DataLoader.
-    device : str
-        Device on which to work with model + data.
-        Defaults to 'cuda' if torch.cuda.is_available is True.
     shuffle: bool
         if True, shuffle training data before each epoch. Default is True.
     normalize_spectrograms : bool
@@ -81,6 +87,9 @@ class TrainConfig:
     dataset: DatasetConfig = field(
         validator=instance_of(DatasetConfig),
     )
+    trainer: TrainerConfig = field(
+        validator=instance_of(TrainerConfig),
+    )

     results_dirname = field(
         converter=converters.optional(expanded_user_path),
@@ -94,7 +103,6 @@ class TrainConfig:
     )

     num_workers = field(validator=instance_of(int), default=2)
-    device = field(validator=instance_of(str), default=device.get_default())
     shuffle = field(
         converter=bool_from_str, validator=instance_of(bool), default=True
     )
@@ -130,7 +138,7 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
         The :class:`dict` passed in should be the one found
         by loading a valid configuration toml file with
         :func:`vak.config.parse.from_toml_path`,
-        and then using key ``prep``,
+        and then using key ``train``,
         i.e., ``TrainConfig.from_config_dict(config_dict['train'])``."""
         for required_key in REQUIRED_KEYS:
             if required_key not in config_dict:
@@ -146,4 +154,5 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
         config_dict["dataset"] = DatasetConfig.from_config_dict(
             config_dict["dataset"]
         )
+        config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
         return cls(**config_dict)
diff --git a/src/vak/config/trainer.py b/src/vak/config/trainer.py
new file mode 100644
index 000000000..b2e751905
--- /dev/null
+++ b/src/vak/config/trainer.py
@@ -0,0 +1,117 @@
+
+from __future__ import annotations
+
+from attrs import asdict, define, field, validators
+
+from .. import common
+
+
+def is_valid_accelerator(instance, attribute, value):
+    """Check if ``accelerator`` is valid"""
+    if value == "auto":
+        raise ValueError(
+            "Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently "
+            "breaks functionality for the command-line interface of `vak`. "
+            "Please see this issue: https://github.com/vocalpy/vak/issues/691 "
+            "If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script."
+        )
+    elif value in ("cpu", "gpu", "tpu", "ipu"):
+        return
+    else:
+        raise ValueError(
+            f"Invalid value for 'accelerator' key in 'trainer' table of configuration file: {value}. "
+            "Value must be one of: {\"cpu\", \"gpu\", \"tpu\", \"ipu\"}"
+        )
+
+
+def is_valid_devices(instance, attribute, value):
+    """Check if ``devices`` is valid"""
+    if not (
+        (isinstance(value, int)) or
+        (isinstance(value, list) and all([isinstance(el, int) for el in value]))
+    ):
+        raise ValueError(
+            f"Invalid value for 'devices' key in 'trainer' table of configuration file: {value}"
+        )
+
+
+@define
+class TrainerConfig:
+    """Class that represents ``trainer`` sub-table
+    in a toml configuration file.
+
+    Used to configure :class:`lightning.pytorch.Trainer`.
+
+    Attributes
+    ----------
+    accelerator : str
+        Value for the `accelerator` argument to :class:`lightning.pytorch.Trainer`.
+        Default is the return value of :func:`vak.common.accelerator.get_default`.
+    devices : int, list of int
+        Number of devices (int) or exact device(s) (list of int) to use.
+
+    Notes
+    -----
+    Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently
+    breaks functionality for the command-line interface of `vak`.
+    Please see this issue: https://github.com/vocalpy/vak/issues/691
+    If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script.
+
+    Likewise, setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1
+    (meaning \"use a single GPU\") or a list with a single number (meaning \"use this exact GPU\") currently
+    breaks functionality for the command-line interface of `vak`.
+    Please see this issue: https://github.com/vocalpy/vak/issues/691
+    If you need to use multiple GPUs, please use `vak` directly in a script.
+    """
+    accelerator: str = field(
+        validator=is_valid_accelerator,
+        default=common.accelerator.get_default()
+    )
+    devices: int | list[int] = field(
+        validator=validators.optional(is_valid_devices),
+        # for devices, we need to look at accelerator in post-init to determine default
+        default=None,
+    )
+
+    def __attrs_post_init__(self):
+        # set default self.devices *before* we validate,
+        # so that we don't throw error because of the default None
+        # that we need to change here depending on the value of self.accelerator
+        if self.devices is None:
+            if self.accelerator == "cpu":
+                # for 'cpu', `devices` is a number of processes; default to one
+                self.devices = 1
+            elif self.accelerator in ("gpu", "tpu", "ipu"):
+                # we can only use a single device, assume there's only one
+                self.devices = [0]
+
+        if self.accelerator in ("gpu", "tpu", "ipu"):
+            if not (
+                (isinstance(self.devices, int) and self.devices == 1) or
+                (isinstance(self.devices, list) and len(self.devices) == 1 and all([isinstance(el, int) for el in self.devices]))
+            ):
+                raise ValueError(
+                    "Setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1 "
+                    "(meaning \"use a single GPU\") or a list with a single number (meaning \"use this exact GPU\") currently "
+                    "breaks functionality for the command-line interface of `vak`. "
+                    "Please see this issue: https://github.com/vocalpy/vak/issues/691 "
+                    "If you need to use multiple GPUs, please use `vak` directly in a script."
+                )
+        elif self.accelerator == "cpu":
+            if isinstance(self.devices, list):
+                raise ValueError(
+                    f"Value for `devices` cannot be a list when `accelerator` is `cpu`. Value was: {self.devices}\n"
+                    "When `accelerator` is `cpu`, please set `devices` to an int greater than zero, e.g. 1."
+                )
+            if self.devices < 1:
+                raise ValueError(
+                    f"When value for 'accelerator' is 'cpu', value for `devices` should be an int > 0, but was: {self.devices}"
+                )
+
+    def asdict(self):
+        """Convert this :class:`TrainerConfig` instance
+        to a :class:`dict` that can be passed
+        into functions that take a ``trainer_config`` argument,
+        like :func:`vak.train` and :func:`vak.predict`.
+ """ + return asdict(self) diff --git a/src/vak/config/valid-version-1.0.toml b/src/vak/config/valid-version-1.1.toml similarity index 93% rename from src/vak/config/valid-version-1.0.toml rename to src/vak/config/valid-version-1.1.toml index ded6aa6ae..95aa50372 100644 --- a/src/vak/config/valid-version-1.0.toml +++ b/src/vak/config/valid-version-1.1.toml @@ -36,7 +36,6 @@ audio_path_key = 'audio_path' [vak.train] root_results_dir = './tests/test_data/results/train' num_workers = 4 -device = 'cuda' batch_size = 11 num_epochs = 2 normalize_spectrograms = true @@ -56,13 +55,16 @@ params = {window_size = 2000} [vak.train.model.TweetyNet] +[vak.train.trainer] +accelerator = "gpu" +devices = [0] + [vak.eval] checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' output_dir = './tests/test_data/prep/learncurve' batch_size = 11 num_workers = 4 -device = 'cuda' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} @@ -73,6 +75,10 @@ splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' [vak.eval.model.TweetyNet] +[vak.eval.trainer] +accelerator = "gpu" +devices = [0] + [vak.learncurve] root_results_dir = './tests/test_data/results/learncurve' batch_size = 11 @@ -85,7 +91,6 @@ patience = 4 results_dir_made_by_main_script = '/some/path/to/learncurve/' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} num_workers = 4 -device = 'cuda' [vak.learncurve.dataset] name = 'IntlDistributedSongbirdConsortiumPack' @@ -95,6 +100,10 @@ params = {window_size = 2000} [vak.learncurve.model.TweetyNet] +[vak.learncurve.trainer] +accelerator = "gpu" +devices = [0] + [vak.predict] checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' @@ -102,7 +111,6 @@ annot_csv_filename = '032312_prep_191224_225910.annot.csv' output_dir = './tests/test_data/prep/learncurve' batch_size = 11 num_workers = 4 -device = 'cuda' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' min_segment_dur = 0.004 majority_vote = false @@ -114,4 +122,8 @@ path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' params = {window_size = 2000} -[vak.predict.model.TweetyNet] \ No newline at end of file +[vak.predict.model.TweetyNet] + +[vak.predict.trainer] +accelerator = "gpu" +devices = [0] diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 628656a51..da396b8d3 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -59,7 +59,7 @@ def is_spect_format(instance, attribute, value): CONFIG_DIR = pathlib.Path(__file__).parent -VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.0.toml") +VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.1.toml") with VALID_TOML_PATH.open("r") as fp: VALID_DICT = tomlkit.load(fp)["vak"] VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys()) diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py index fa1209f1d..665ec7d0d 100644 --- a/src/vak/eval/eval_.py +++ b/src/vak/eval/eval_.py @@ -16,6 +16,7 @@ def eval( model_config: dict, dataset_config: dict, + trainer_config: dict, checkpoint_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, @@ -36,6 +37,9 @@ def eval( dataset_config: dict Dataset configuration in a :class:`dict`. 
Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str, pathlib.Path path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -111,25 +115,25 @@ def eval( eval_frame_classification_model( model_config=model_config, dataset_config=dataset_config, + trainer_config=trainer_config, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, output_dir=output_dir, num_workers=num_workers, split=split, spect_scaler_path=spect_scaler_path, - device=device, post_tfm_kwargs=post_tfm_kwargs, ) elif model_family == "ParametricUMAPModel": eval_parametric_umap_model( model_config=model_config, dataset_config=dataset_config, + trainer_config=trainer_config, checkpoint_path=checkpoint_path, output_dir=output_dir, batch_size=batch_size, num_workers=num_workers, split=split, - device=device, ) else: raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 9757287e8..60be44b3e 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -10,7 +10,7 @@ import joblib import pandas as pd -import pytorch_lightning as lightning +import lightning import torch.utils.data from .. import datasets, models, transforms @@ -23,6 +23,7 @@ def eval_frame_classification_model( model_config: dict, dataset_config: dict, + trainer_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, output_dir: str | pathlib.Path, @@ -30,7 +31,6 @@ def eval_frame_classification_model( split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, - device: str | None = None, ) -> None: """Evaluate a trained model. @@ -42,6 +42,9 @@ def eval_frame_classification_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str, pathlib.Path Path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -71,18 +74,15 @@ def eval_frame_classification_model( a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. Notes ----- - Note that unlike ``core.predict``, this function + Note that unlike :func:`core.predict`, this function can modify ``labelmap`` so that metrics like edit distance are correctly computed, by converting any string labels in ``labelmap`` with multiple characters to (mock) single-character labels, - with ``vak.labels.multi_char_labels_to_single_char``. + with :func:`vak.labels.multi_char_labels_to_single_char`. 
""" # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( @@ -190,14 +190,12 @@ def eval_frame_classification_model( model.load_state_dict_from_path(checkpoint_path) - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir) + trainer = lightning.pytorch.Trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], + logger=trainer_logger + ) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. of val loaders metric_vals = trainer.validate(model, dataloaders=val_loader)[0] diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py index 107d8d844..cf4b13f41 100644 --- a/src/vak/eval/parametric_umap.py +++ b/src/vak/eval/parametric_umap.py @@ -8,7 +8,7 @@ from datetime import datetime import pandas as pd -import pytorch_lightning as lightning +import lightning import torch.utils.data from .. import models, transforms @@ -25,8 +25,8 @@ def eval_parametric_umap_model( output_dir: str | pathlib.Path, batch_size: int, num_workers: int, + trainer_config: dict, split: str = "test", - device: str | None = None, ) -> None: """Evaluate a trained model. @@ -44,15 +44,15 @@ def eval_parametric_umap_model( Path to location where .csv files with evaluation metrics should be saved. batch_size : int Number of samples per batch fed into model. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. split : str Split of dataset on which model should be evaluated. One of {'train', 'val', 'test'}. Default is 'test'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( @@ -111,14 +111,12 @@ def eval_parametric_umap_model( model.load_state_dict_from_path(checkpoint_path) - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir) + trainer = lightning.pytorch.Trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], + logger=trainer_logger + ) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. 
of val loaders metric_vals = trainer.validate(model, dataloaders=val_loader)[0] diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py index 8ca2e11b7..fc95d59bb 100644 --- a/src/vak/learncurve/frame_classification.py +++ b/src/vak/learncurve/frame_classification.py @@ -19,6 +19,7 @@ def learning_curve_for_frame_classification_model( model_config: dict, dataset_config: dict, + trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, @@ -29,7 +30,6 @@ def learning_curve_for_frame_classification_model( val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, - device: str | None = None, ) -> None: """Generate results for a learning curve, where model performance is measured as a @@ -49,6 +49,9 @@ def learning_curve_for_frame_classification_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. dataset_path : str path to where dataset was saved as a csv. batch_size : int @@ -188,6 +191,7 @@ def learning_curve_for_frame_classification_model( train_frame_classification_model( model_config, dataset_config, + trainer_config, batch_size, num_epochs, num_workers, @@ -197,7 +201,6 @@ def learning_curve_for_frame_classification_model( val_step=val_step, ckpt_step=ckpt_step, patience=patience, - device=device, subset=subset, ) @@ -238,6 +241,7 @@ def learning_curve_for_frame_classification_model( eval_frame_classification_model( model_config, dataset_config, + trainer_config, ckpt_path, labelmap_path, results_path_this_replicate, @@ -245,7 +249,6 @@ def learning_curve_for_frame_classification_model( "test", spect_scaler_path, post_tfm_kwargs, - device, ) # ---- make a csv for analysis ------------------------------------------------------------------------------------- diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 0b6e443bf..def4b722f 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -15,6 +15,7 @@ def learning_curve( model_config: dict, dataset_config: dict, + trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, @@ -25,7 +26,6 @@ def learning_curve( val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, - device: str | None = None, ) -> None: """Generate results for a learning curve, where model performance is measured as a @@ -45,6 +45,9 @@ def learning_curve( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -74,10 +77,6 @@ def learning_curve( a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. - device : str - Device on which to work with model + data. - Default is None. If None, then a device will be selected with vak.device.get_default. - That function defaults to 'cuda' if torch.cuda.is_available is True. shuffle: bool if True, shuffle training data before each epoch. Default is True. 
     normalize_spectrograms : bool
@@ -118,6 +117,7 @@
         learning_curve_for_frame_classification_model(
             model_config=model_config,
             dataset_config=dataset_config,
+            trainer_config=trainer_config,
             batch_size=batch_size,
             num_epochs=num_epochs,
             num_workers=num_workers,
@@ -128,7 +128,6 @@
             val_step=val_step,
             ckpt_step=ckpt_step,
             patience=patience,
-            device=device,
         )
     else:
         raise ValueError(f"Model family not recognized: {model_family}")
diff --git a/src/vak/models/base.py b/src/vak/models/base.py
index fd27b7ae3..42542099f 100644
--- a/src/vak/models/base.py
+++ b/src/vak/models/base.py
@@ -7,14 +7,14 @@
 import inspect
 from typing import Callable, ClassVar

-import pytorch_lightning as lightning
+import lightning
 import torch

 from .definition import ModelDefinition
 from .definition import validate as validate_definition


-class Model(lightning.LightningModule):
+class Model(lightning.pytorch.LightningModule):
     """Base class for a model in ``vak``,
     that other families of models should subclass.

@@ -286,7 +286,7 @@ def load_state_dict_from_path(self, ckpt_path):
         in that chekcpoint.

         This method allows loading a state dict into an instance.
-        It's necessary because `lightning.LightningModule.load`` is a
+        It's necessary because ``lightning.pytorch.LightningModule.load`` is a
         ``classmethod``, so calling that method will trigger
         ``LightningModule.__init__`` instead of running
         ``vak.models.Model.__init__``.
@@ -297,7 +297,7 @@ def load_state_dict_from_path(self, ckpt_path):
             Path to a checkpoint saved by a model in ``vak``.
             This checkpoint has the same key-value pairs as
             any other checkpoint saved by a
-            ``lightning.LightningModule``.
+            ``lightning.pytorch.LightningModule``.

         Returns
         -------
diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py
index 305018e8c..0b1777e80 100644
--- a/src/vak/models/frame_classification_model.py
+++ b/src/vak/models/frame_classification_model.py
@@ -160,7 +160,7 @@ def __init__(
     def configure_optimizers(self):
         """Returns the model's optimizer.

-        Method required by ``lightning.LightningModule``.
+        Method required by ``lightning.pytorch.LightningModule``.
         This method returns the ``optimizer`` instance passed into ``__init__``.
         If None was passed in, an instance that was created
         with default arguments will be returned.
@@ -185,7 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     def training_step(self, batch: tuple, batch_idx: int):
         """Perform one training step.

-        Method required by ``lightning.LightningModule``.
+        Method required by ``lightning.pytorch.LightningModule``.

         Parameters
         ----------
@@ -209,7 +209,7 @@ def training_step(self, batch: tuple, batch_idx: int):
     def validation_step(self, batch: tuple, batch_idx: int):
         """Perform one validation step.

-        Method required by ``lightning.LightningModule``.
+        Method required by ``lightning.pytorch.LightningModule``.
         Logs metrics using ``self.log``

         Parameters
@@ -322,7 +322,7 @@ def validation_step(self, batch: tuple, batch_idx: int):
     def predict_step(self, batch: tuple, batch_idx: int):
         """Perform one prediction step.

-        Method required by ``lightning.LightningModule``.
+        Method required by ``lightning.pytorch.LightningModule``.
Parameters ---------- diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py index 4d2c2cb94..5abaf7bbf 100644 --- a/src/vak/models/parametric_umap_model.py +++ b/src/vak/models/parametric_umap_model.py @@ -11,7 +11,7 @@ import pathlib from typing import Callable, ClassVar, Type -import pytorch_lightning as lightning +import lightning import torch import torch.utils.data @@ -126,7 +126,7 @@ def from_config(cls, config: dict): ) -class ParametricUMAPDatamodule(lightning.LightningDataModule): +class ParametricUMAPDatamodule(lightning.pytorch.LightningDataModule): def __init__( self, dataset, @@ -178,7 +178,7 @@ def __init__( def fit( self, - trainer: lightning.Trainer, + trainer: lightning.pytorch.Trainer, dataset_path: str | pathlib.Path, transform=None, ): diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index 765fb3134..e5a3cb6f7 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -10,13 +10,12 @@ import crowsetta import joblib import numpy as np -import pytorch_lightning as lightning +import lightning import torch.utils.data from tqdm import tqdm from .. import datasets, models, transforms from ..common import constants, files, validators -from ..common.device import get_default as get_default_device from ..datasets.frame_classification import FramesDataset logger = logging.getLogger(__name__) @@ -25,12 +24,12 @@ def predict_with_frame_classification_model( model_config: dict, dataset_config: dict, + trainer_config: dict, checkpoint_path, labelmap_path, num_workers=2, timebins_key="t", spect_scaler_path=None, - device=None, annot_csv_filename=None, output_dir=None, min_segment_dur=None, @@ -48,6 +47,9 @@ def predict_with_frame_classification_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str @@ -59,9 +61,6 @@ def predict_with_frame_classification_model( key for accessing spectrogram in files. Default is 's'. timebins_key : str key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. 
If spectrograms were normalized and this is not provided, will give @@ -124,9 +123,6 @@ def predict_with_frame_classification_model( f"value specified for output_dir is not recognized as a directory: {output_dir}" ) - if device is None: - device = get_default_device() - # ---------------- load data for prediction ------------------------------------------------------------------------ if spect_scaler_path: logger.info(f"loading SpectScaler from path: {spect_scaler_path}") @@ -226,14 +222,12 @@ def predict_with_frame_classification_model( ) model.load_state_dict_from_path(checkpoint_path) - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir) + trainer = lightning.pytorch.Trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], + logger=trainer_logger + ) logger.info(f"running predict method of {model_name}") results = trainer.predict(model, pred_loader) diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py index 4e54336f4..331b571b2 100644 --- a/src/vak/predict/parametric_umap.py +++ b/src/vak/predict/parametric_umap.py @@ -6,12 +6,11 @@ import os import pathlib -import pytorch_lightning as lightning +import lightning import torch.utils.data from .. import datasets, models, transforms from ..common import validators -from ..common.device import get_default as get_default_device from ..datasets.parametric_umap import ParametricUMAPDataset logger = logging.getLogger(__name__) @@ -20,12 +19,12 @@ def predict_with_parametric_umap_model( model_config: dict, dataset_config: dict, + trainer_config: dict, checkpoint_path, num_workers=2, transform_params: dict | None = None, dataset_params: dict | None = None, timebins_key="t", - device=None, output_dir=None, ): """Make predictions on a dataset with a trained @@ -39,6 +38,9 @@ def predict_with_parametric_umap_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model num_workers : int @@ -54,9 +56,6 @@ def predict_with_parametric_umap_model( Optional, default is None. timebins_key : str key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. annot_csv_filename : str name of .csv file containing predicted annotations. 
Default is None, in which case the name of the dataset .csv @@ -97,9 +96,6 @@ def predict_with_parametric_umap_model( f"value specified for output_dir is not recognized as a directory: {output_dir}" ) - if device is None: - device = get_default_device() - # ---------------- load data for prediction ------------------------------------------------------------------------ model_name = model_config["name"] # TODO: fix this when we build transforms into datasets @@ -157,14 +153,12 @@ def predict_with_parametric_umap_model( ) model.load_state_dict_from_path(checkpoint_path) - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir) + trainer = lightning.pytorch.Trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], + logger=trainer_logger + ) logger.info(f"running predict method of {model_name}") results = trainer.predict(model, pred_loader) # noqa : F841 diff --git a/src/vak/predict/predict_.py b/src/vak/predict/predict_.py index 60373b11d..2cc60ea61 100644 --- a/src/vak/predict/predict_.py +++ b/src/vak/predict/predict_.py @@ -8,7 +8,7 @@ from .. import models from ..common import validators -from ..common.device import get_default as get_default_device +from ..common.accelerator import get_default as get_default_device from .frame_classification import predict_with_frame_classification_model logger = logging.getLogger(__name__) @@ -17,6 +17,7 @@ def predict( model_config: dict, dataset_config: dict, + trainer_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, num_workers: int = 2, @@ -37,8 +38,12 @@ def predict( Model configuration in a ``dict``, as loaded from a .toml file, and used by the model method ``from_config``. - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str @@ -51,9 +56,6 @@ def predict( Argument to torch.DataLoader. Default is 2. timebins_key : str key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. 
If spectrograms were normalized and this is not provided, will give @@ -130,12 +132,12 @@ def predict( predict_with_frame_classification_model( model_config=model_config, dataset_config=dataset_config, + trainer_config=trainer_config, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, num_workers=num_workers, timebins_key=timebins_key, spect_scaler_path=spect_scaler_path, - device=device, annot_csv_filename=annot_csv_filename, output_dir=output_dir, min_segment_dur=min_segment_dur, diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py index 2af4b586d..41d521013 100644 --- a/src/vak/prep/frame_classification/make_splits.py +++ b/src/vak/prep/frame_classification/make_splits.py @@ -228,7 +228,7 @@ def make_splits( network model. As returned by :func:`vak.labels.to_map`. audio_format : str A :class:`string` representing the format of audio files. - One of :constant:`vak.common.constants.VALID_AUDIO_FORMATS`. + One of :const:`vak.common.constants.VALID_AUDIO_FORMATS`. spect_key : str Key for accessing spectrogram in files. Default is 's'. timebins_key : str diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 25e07a062..e2d9074f6 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -14,7 +14,6 @@ from .. import datasets, models, transforms from ..common import validators -from ..common.device import get_default as get_default_device from ..common.trainer import get_default_trainer from ..datasets.frame_classification import FramesDataset, WindowDataset @@ -29,6 +28,7 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: def train_frame_classification_model( model_config: dict, dataset_config: dict, + trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, @@ -40,7 +40,6 @@ def train_frame_classification_model( val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, - device: str | None = None, subset: str | None = None, ) -> None: """Train a model from the frame classification family @@ -60,6 +59,9 @@ def train_frame_classification_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -86,10 +88,6 @@ def train_frame_classification_model( results_path : str, pathlib.Path Directory where results will be saved. If specified, this parameter overrides ``root_results_dir``. - device : str - Device on which to work with model + data. - Default is None. If None, then a device will be selected with vak.split.get_default. - That function defaults to 'cuda' if torch.cuda.is_available is True. shuffle: bool if True, shuffle training data before each epoch. Default is True. 
normalize_spectrograms : bool @@ -279,9 +277,6 @@ def train_frame_classification_model( else: val_loader = None - if device is None: - device = get_default_device() - model = models.get( model_name, model_config, @@ -308,11 +303,12 @@ def train_frame_classification_model( "patience": patience, } trainer = get_default_trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], max_steps=max_steps, log_save_dir=results_model_root, val_step=val_step, default_callback_kwargs=default_callback_kwargs, - device=device, ) train_time_start = datetime.datetime.now() logger.info(f"Training start time: {train_time_start.isoformat()}") diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py index f9180e5c0..0fd72ca33 100644 --- a/src/vak/train/parametric_umap.py +++ b/src/vak/train/parametric_umap.py @@ -7,12 +7,11 @@ import pathlib import pandas as pd -import pytorch_lightning as lightning +import lightning import torch.utils.data from .. import datasets, models, transforms from ..common import validators -from ..common.device import get_default as get_default_device from ..common.paths import generate_results_dir_name_as_path from ..datasets.parametric_umap import ParametricUMAPDataset @@ -25,22 +24,17 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: def get_trainer( + accelerator: str, + devices: int | list[int], max_epochs: int, ckpt_root: str | pathlib.Path, ckpt_step: int, log_save_dir: str | pathlib.Path, - device: str = "cuda", -) -> lightning.Trainer: - """Returns an instance of ``lightning.Trainer`` +) -> lightning.pytorch.Trainer: + """Returns an instance of ``lightning.pytorch.Trainer`` with a default set of callbacks. Used by ``vak.core`` functions.""" - # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 - if device == "cuda": - accelerator = "gpu" - else: - accelerator = "auto" - - ckpt_callback = lightning.callbacks.ModelCheckpoint( + ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( dirpath=ckpt_root, filename="checkpoint", every_n_train_steps=ckpt_step, @@ -50,7 +44,7 @@ def get_trainer( ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" ckpt_callback.FILE_EXTENSION = ".pt" - val_ckpt_callback = lightning.callbacks.ModelCheckpoint( + val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( monitor="val_loss", dirpath=ckpt_root, save_top_k=1, @@ -66,11 +60,12 @@ def get_trainer( val_ckpt_callback, ] - logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir) + logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) - trainer = lightning.Trainer( + trainer = lightning.pytorch.Trainer( max_epochs=max_epochs, accelerator=accelerator, + devices=devices, logger=logger, callbacks=callbacks, ) @@ -80,6 +75,7 @@ def get_trainer( def train_parametric_umap_model( model_config: dict, dataset_config: dict, + trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, @@ -89,7 +85,6 @@ def train_parametric_umap_model( shuffle: bool = True, val_step: int | None = None, ckpt_step: int | None = None, - device: str | None = None, subset: str | None = None, ) -> None: """Train a model from the parametric UMAP family @@ -109,6 +104,9 @@ def train_parametric_umap_model( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`. 
+ Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -137,10 +135,6 @@ def train_parametric_umap_model( If ckpt_step is n, then a checkpoint is saved every time the global step / n is a whole number, i.e., when ckpt_step modulo the global step is 0. Default is None, in which case checkpoint is only saved at the last epoch. - device : str - Device on which to work with model + data. - Default is None. If None, then a device will be selected with vak.split.get_default. - That function defaults to 'cuda' if torch.cuda.is_available is True. shuffle: bool if True, shuffle training data before each epoch. Default is True. split : str @@ -248,9 +242,6 @@ def train_parametric_umap_model( else: val_loader = None - if device is None: - device = get_default_device() - model = models.get( model_name, model_config, @@ -269,9 +260,10 @@ def train_parametric_umap_model( ckpt_root.mkdir() logger.info(f"training {model_name}") trainer = get_trainer( + accelerator=trainer_config["accelerator"], + devices=trainer_config["devices"], max_epochs=num_epochs, log_save_dir=results_model_root, - device=device, ckpt_root=ckpt_root, ckpt_step=ckpt_step, ) diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py index 96926967d..1a5527c88 100644 --- a/src/vak/train/train_.py +++ b/src/vak/train/train_.py @@ -16,6 +16,7 @@ def train( model_config: dict, dataset_config: dict, + trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, @@ -27,7 +28,6 @@ def train( val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, - device: str | None = None, subset: str | None = None, ): """Train a model and save results. @@ -44,6 +44,9 @@ def train( dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + trainer_config: dict + Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. 
window_size : int size of windows taken from spectrograms, in number of time bins, shown to neural networks @@ -147,6 +150,7 @@ def train( train_frame_classification_model( model_config=model_config, dataset_config=dataset_config, + trainer_config=trainer_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, @@ -158,13 +162,13 @@ def train( val_step=val_step, ckpt_step=ckpt_step, patience=patience, - device=device, subset=subset, ) elif model_family == "ParametricUMAPModel": train_parametric_umap_model( model_config=model_config, dataset_config=dataset_config, + trainer_config=trainer_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, @@ -173,7 +177,6 @@ def train( shuffle=shuffle, val_step=val_step, ckpt_step=ckpt_step, - device=device, subset=subset, ) else: diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml index b38979f48..32940a679 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml @@ -17,7 +17,7 @@ transform_type = "log_spect_plus_one" checkpoint_path = "tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP/results_230727_210112/ConvEncoderUMAP/checkpoints/checkpoint.pt" batch_size = 64 num_workers = 16 -device = "cuda" + output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/ConvEncoderUMAP" [vak.eval.model.ConvEncoderUMAP.network] @@ -31,3 +31,7 @@ n_components = 2 [vak.eval.model.ConvEncoderUMAP.optimizer] lr = 0.001 + +[vak.eval.trainer] +accelerator = "gpu" +devices = [0] \ No newline at end of file diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml index 8be5a4d3a..f188b650c 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml @@ -21,7 +21,7 @@ num_epochs = 1 val_step = 1 ckpt_step = 1000 num_workers = 16 -device = "cuda" + root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP" [vak.train.model.ConvEncoderUMAP.network] @@ -35,3 +35,7 @@ n_components = 2 [vak.train.model.ConvEncoderUMAP.optimizer] lr = 0.001 + +[vak.train.trainer] +accelerator = "gpu" +devices = [0] diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index 12bfcba84..98e51a7b0 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -19,7 +19,7 @@ checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRep labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" batch_size = 11 num_workers = 16 -device = "cuda" + spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet" @@ -43,3 +43,7 @@ hidden_size = 32 [vak.eval.model.TweetyNet.optimizer] lr = 0.001 + +[vak.eval.trainer] 
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
index 59868a28a..744cec82e 100644
--- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml
@@ -27,7 +27,7 @@ val_step = 50
 ckpt_step = 200
 patience = 4
 num_workers = 16
-device = "cuda"
+
 root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet"
 [vak.learncurve.post_tfm_kwargs]
@@ -50,3 +50,7 @@ hidden_size = 32
 [vak.learncurve.model.TweetyNet.optimizer]
 lr = 0.001
+
+[vak.learncurve.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
index 3d794f314..0d1cd8f13 100644
--- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml
@@ -18,7 +18,7 @@ checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRep
 labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json"
 batch_size = 11
 num_workers = 16
-device = "cuda"
+
 output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet"
 annot_csv_filename = "bl26lb16.041912.annot.csv"
@@ -38,3 +38,7 @@ hidden_size = 32
 [vak.predict.model.TweetyNet.optimizer]
 lr = 0.001
+
+[vak.predict.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
index 9b751e7f0..c5d26bf88 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml
@@ -25,7 +25,7 @@ val_step = 50
 ckpt_step = 200
 patience = 4
 num_workers = 16
-device = "cuda"
+
 root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet"
 [vak.train.dataset]
@@ -44,3 +44,7 @@ hidden_size = 32
 [vak.train.model.TweetyNet.optimizer]
 lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
index c7ca91a96..5af7f78d2 100644
--- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
+++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml
@@ -25,7 +25,7 @@ val_step = 50
 ckpt_step = 200
 patience = 4
 num_workers = 16
-device = "cuda"
+
 root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audio_cbin_annot_notmat/TweetyNet"
 checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt"
 spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect"
@@ -46,3 +46,7 @@ hidden_size = 32
 [vak.train.model.TweetyNet.optimizer]
 lr = 0.001
+
+[vak.train.trainer]
+accelerator = "gpu"
+devices = [0]
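Note for readers without CUDA: every config diff above pins `accelerator = "gpu"`. A CPU-only variant of the same sub-table should also be accepted; this is a sketch inferred from the `{"accelerator": "cpu", "devices": 1}` table that the test fixtures below use, not a file in this changeset:

[vak.train.trainer]
# "cpu" runs the model on the CPU; for this accelerator,
# devices is an integer count of devices, not a list of GPU indices
accelerator = "cpu"
devices = 1

diff --git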
a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index c66e9c34d..05d897eba 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -25,7 +25,7 @@ val_step = 50 ckpt_step = 200 patience = 4 num_workers = 16 -device = "cuda" + root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" @@ -45,3 +45,7 @@ hidden_size = 32 [vak.train.model.TweetyNet.optimizer] lr = 0.001 + +[vak.train.trainer] +accelerator = "gpu" +devices = [0] diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index a9aaaf112..4796edb60 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -25,7 +25,7 @@ val_step = 50 ckpt_step = 200 patience = 4 num_workers = 16 -device = "cuda" + root_results_dir = "./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet" [vak.train.dataset] @@ -44,3 +44,7 @@ hidden_size = 32 [vak.train.model.TweetyNet.optimizer] lr = 0.001 + +[vak.train.trainer] +accelerator = "gpu" +devices = [0] diff --git a/tests/data_for_tests/configs/invalid_key_config.toml b/tests/data_for_tests/configs/invalid_key_config.toml index 0012c6d6c..76770862c 100644 --- a/tests/data_for_tests/configs/invalid_key_config.toml +++ b/tests/data_for_tests/configs/invalid_key_config.toml @@ -35,3 +35,7 @@ params = { window_size = 88 } [vak.train.model.TweetyNet.optimizer] learning_rate = 0.001 + +[vak.train.trainer] +accelerator = "gpu" +devices = [0] diff --git a/tests/data_for_tests/configs/invalid_table_config.toml b/tests/data_for_tests/configs/invalid_table_config.toml index 24998129d..09898375f 100644 --- a/tests/data_for_tests/configs/invalid_table_config.toml +++ b/tests/data_for_tests/configs/invalid_table_config.toml @@ -32,3 +32,7 @@ save_only_single_checkpoint_file = true [vak.train.model.TweetyNet.optimizer] learning_rate = 0.001 + +[vak.train.trainer] +accelerator = "gpu" +devices = [0] diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml index a4fcd542d..3107b538c 100644 --- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml +++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml @@ -27,9 +27,13 @@ val_step = 50 ckpt_step = 200 patience = 4 num_workers = 16 -device = "cuda" + root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat" +[vak.train.trainer] +accelerator = "gpu" +devices = [0] + [vak.learncurve] normalize_spectrograms = true batch_size = 11 @@ -40,7 +44,7 @@ patience = 4 num_workers = 16 train_set_durs = [ 4, 6 ] num_replicates = 2 -device = "cuda" + root_results_dir = './tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat' [vak.learncurve.model.TweetyNet.optimizer] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index ac174ea00..18d506be7 100644 --- a/tests/fixtures/__init__.py +++ 
b/tests/fixtures/__init__.py
@@ -6,6 +6,7 @@
 from .dataframe import *
 from .dataset import *
 from .device import *
+from .trainer import *
 from .model import *
 from .path import *
 from .source_files import *
diff --git a/tests/fixtures/trainer.py b/tests/fixtures/trainer.py
new file mode 100644
index 000000000..8d8becd6d
--- /dev/null
+++ b/tests/fixtures/trainer.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+
+
+TRAINER_TABLE = [
+    {"accelerator": "cpu", "devices": 1}
+]
+if torch.cuda.is_available():
+    # append the GPU table; a bare dict literal here would be a no-op
+    TRAINER_TABLE.append({"accelerator": "gpu", "devices": [0]})
+
+
+@pytest.fixture(params=TRAINER_TABLE)
+def trainer_table(request):
+    """Parametrized 'trainer' table for config files.
+
+    Causes any test using this :func:`trainer_table` fixture
+    to run just once if only a cpu is available,
+    and twice if ``torch.cuda.is_available()`` returns ``True``."""
+    return request.param
diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py
index 0ee9aba65..5a2cfdd90 100644
--- a/tests/test_cli/test_eval.py
+++ b/tests/test_cli/test_eval.py
@@ -18,7 +18,7 @@
     ],
 )
 def test_eval(
-    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
 ):
     output_dir = tmp_path.joinpath(
         f"test_eval_{audio_format}_{spect_format}_{annot_format}"
     )
     output_dir.mkdir()
     keys_to_change = [
         {"table": "eval", "key": "output_dir", "value": str(output_dir)},
-        {"table": "eval", "key": "device", "value": device},
+        {"table": "eval", "key": "trainer", "value": trainer_table},
     ]
     toml_path = specific_config_toml_path(
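For orientation — not part of this changeset — a minimal sketch of how the new fixture is meant to be consumed; the test name is hypothetical, and `vak.config.TrainerConfig` is the class the config tests below exercise:

import vak

def test_trainer_table_round_trip(trainer_table):
    # trainer_table is {"accelerator": "cpu", "devices": 1} on CPU-only hosts,
    # with {"accelerator": "gpu", "devices": [0]} added when CUDA is available
    trainer_config = vak.config.TrainerConfig(**trainer_table)
    assert trainer_config.asdict()["accelerator"] == trainer_table["accelerator"]
    assert trainer_config.asdict()["devices"] == trainer_table["devices"]

diff --git a/tests/test_cli/test_learncurve.py b/tests/test_cli/test_learncurve.py
index 7fd0a3a8b..c249b4164 100644
--- a/tests/test_cli/test_learncurve.py
+++ b/tests/test_cli/test_learncurve.py
@@ -10,7 +10,7 @@
 from .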
import cli_asserts -def test_learncurve(specific_config_toml_path, tmp_path, device): +def test_learncurve(specific_config_toml_path, tmp_path, trainer_table): root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") root_results_dir.mkdir() @@ -20,7 +20,7 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): "key": "root_results_dir", "value": str(root_results_dir), }, - {"table": "learncurve", "key": "device", "value": device}, + {"table": "learncurve", "key": "trainer", "value": trainer_table}, ] toml_path = specific_config_toml_path( diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py index 30c78d3c5..d7766bcff 100644 --- a/tests/test_cli/test_predict.py +++ b/tests/test_cli/test_predict.py @@ -17,7 +17,7 @@ ], ) def test_predict( - model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device + model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table ): output_dir = tmp_path.joinpath( f"test_predict_{audio_format}_{spect_format}_{annot_format}" @@ -26,7 +26,7 @@ def test_predict( keys_to_change = [ {"table": "predict", "key": "output_dir", "value": str(output_dir)}, - {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "trainer", "value": trainer_table}, ] toml_path = specific_config_toml_path( diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py index a23acab3c..21c699683 100644 --- a/tests/test_cli/test_train.py +++ b/tests/test_cli/test_train.py @@ -19,7 +19,7 @@ ], ) def test_train( - model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device + model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table ): root_results_dir = tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() @@ -30,7 +30,7 @@ def test_train( "key": "root_results_dir", "value": str(root_results_dir), }, - {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "trainer", "value": trainer_table}, ] toml_path = specific_config_toml_path( diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index de0ac681e..7917e4a6b 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -14,7 +14,7 @@ class TestEval: 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': { @@ -45,6 +45,7 @@ class TestEval: def test_init(self, config_dict): config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer']) eval_config = vak.config.EvalConfig(**config_dict) @@ -58,7 +59,7 @@ def test_init(self, config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': { @@ -108,7 +109,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': { @@ -127,7 +128,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': { @@ -158,7 +159,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': { @@ -193,7 +194,7 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py index 6d2d65270..fa21c152b 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -17,7 +17,7 @@ class TestLearncurveConfig: 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, 'model': { @@ -47,6 +47,7 @@ class TestLearncurveConfig: def test_init(self, config_dict): config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer']) learncurve_config = 
vak.config.LearncurveConfig(**config_dict) @@ -63,7 +64,7 @@ def test_init(self, config_dict): 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, 'model': { @@ -118,7 +119,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, 'dataset': { @@ -137,7 +138,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, 'model': { @@ -171,7 +172,7 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, 'model': { 'TweetyNet': { diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py index 8d81dcf07..8dc77e7d8 100644 --- a/tests/test_config/test_predict.py +++ b/tests/test_config/test_predict.py @@ -15,7 +15,7 @@ class TestPredictConfig: 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', 'model': { @@ -43,6 +43,7 @@ class TestPredictConfig: def test_init(self, config_dict): config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer']) predict_config = vak.config.PredictConfig(**config_dict) @@ -57,7 +58,7 @@ def test_init(self, config_dict): 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', 'model': { @@ -104,7 +105,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', 'model': { @@ 
-137,7 +138,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', 'model': { @@ -167,7 +168,7 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', 'batch_size': 11, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', 'dataset': { diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index e5a3127da..970b1abb3 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -17,7 +17,7 @@ class TestTrainConfig: 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'model': { 'TweetyNet': { @@ -46,6 +46,7 @@ class TestTrainConfig: def test_init(self, config_dict): config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + config_dict['trainer'] = vak.config.TrainerConfig(**config_dict['trainer']) train_config = vak.config.TrainConfig(**config_dict) @@ -62,7 +63,7 @@ def test_init(self, config_dict): 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'model': { 'TweetyNet': { @@ -112,7 +113,7 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'dataset': { 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' @@ -129,7 +130,7 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): 'ckpt_step': 200, 'patience': 4, 'num_workers': 16, - 'device': 'cuda', + 'trainer': {'accelerator': 'gpu', 'devices': [0]}, 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'model': { 'TweetyNet': { diff --git a/tests/test_config/test_trainer.py b/tests/test_config/test_trainer.py new file mode 100644 index 000000000..649bae96e --- /dev/null +++ b/tests/test_config/test_trainer.py @@ -0,0 +1,138 @@ +import pytest + +import vak.config.trainer + + +class TestTrainerConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'accelerator': 'cpu', + }, + { + 'accelerator': 'gpu', + 'devices': [0], + }, + { + 'accelerator': 'gpu', + 'devices': [1], + }, + { + 'accelerator': 'cpu', + 'devices': 1, + }, + { + 
'devices': 1,
+            },
+        ]
+    )
+    def test_init(self, config_dict):
+        trainer_config = vak.config.trainer.TrainerConfig(**config_dict)
+
+        assert isinstance(trainer_config, vak.config.trainer.TrainerConfig)
+        if 'accelerator' in config_dict:
+            assert getattr(trainer_config, 'accelerator') == config_dict['accelerator']
+        else:
+            # TODO: mock `accelerator.get_default` here, return either 'cpu' or 'gpu'
+            assert getattr(trainer_config, 'accelerator') == vak.common.accelerator.get_default()
+        if 'devices' in config_dict:
+            assert getattr(trainer_config, 'devices') == config_dict['devices']
+        else:
+            if config_dict['accelerator'] == 'cpu':
+                assert getattr(trainer_config, 'devices') == 1
+            elif config_dict['accelerator'] in ('gpu', 'tpu', 'ipu'):
+                assert getattr(trainer_config, 'devices') == [0]
+
+    @pytest.mark.parametrize(
+        'config_dict, expected_exception',
+        [
+            # 'accelerator' can't be 'auto', breaks train/eval/prep/learncurve functions
+            (
+                {
+                    'accelerator': 'auto',
+                },
+                ValueError,
+            ),
+            # raises an error because training in parallel across multiple GPUs won't work right now
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': [0, 1],
+                },
+                ValueError,
+            ),
+            # 'devices' can't be -1, breaks train/eval/prep/learncurve functions
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': -1,
+                },
+                ValueError,
+            ),
+            (
+                {
+                    'accelerator': 'gpu',
+                    'devices': 'auto',
+                },
+                ValueError,
+            ),
+            # when accelerator is CPU, devices must be an int > 0
+            (
+                {
+                    'accelerator': 'cpu',
+                    'devices': -1,
+                },
+                ValueError,
+            )
+        ]
+    )
+    def test_init_raises(self, config_dict, expected_exception):
+        with pytest.raises(expected_exception):
+            vak.config.trainer.TrainerConfig(**config_dict)
+
+    @pytest.mark.parametrize(
+        'config_dict',
+        [
+            {
+                'accelerator': 'cpu',
+            },
+            {
+                'accelerator': 'gpu',
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [0],
+            },
+            {
+                'accelerator': 'gpu',
+                'devices': [1],
+            },
+            {
+                'accelerator': 'cpu',
+                'devices': 1,
+            },
+            {
+                'devices': 1,
+            },
+        ]
+    )
+    def test_asdict(self, config_dict):
+        trainer_config = vak.config.trainer.TrainerConfig(**config_dict)
+
+        trainer_config_asdict = trainer_config.asdict()
+
+        assert isinstance(trainer_config_asdict, dict)
+        if 'accelerator' in config_dict:
+            assert trainer_config_asdict['accelerator'] == config_dict['accelerator']
+        else:
+            # TODO: mock `accelerator.get_default` here, return either 'cpu' or 'gpu'
+            assert trainer_config_asdict['accelerator'] == vak.common.accelerator.get_default()
+        if 'devices' in config_dict:
+            assert trainer_config_asdict['devices'] == config_dict['devices']
+        else:
+            if config_dict["accelerator"] == 'cpu':
+                assert trainer_config_asdict['devices'] == 1
+            elif config_dict["accelerator"] == 'gpu':
+                assert trainer_config_asdict['devices'] == [0]
diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py
index 7a874afb9..21d543828 100644
--- a/tests/test_eval/test_eval.py
+++ b/tests/test_eval/test_eval.py
@@ -31,7 +31,7 @@ def test_eval(
     keys_to_change = [
         {"table": "eval", "key": "output_dir", "value": str(output_dir)},
-        {"table": "eval", "key": "device", "value": 'cpu'},
+        {"table": "eval", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}},
     ]
     toml_path = specific_config_toml_path(
@@ -51,13 +51,13 @@
     vak.eval.eval(
         model_config=cfg.eval.model.asdict(),
         dataset_config=cfg.eval.dataset.asdict(),
+        trainer_config=cfg.eval.trainer.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
-        device=cfg.eval.device,
         post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
     )
diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py
index ee55825a5..40eb04c07 100644
--- a/tests/test_eval/test_frame_classification.py
+++ b/tests/test_eval/test_frame_classification.py
@@ -46,7 +46,7 @@ def test_eval_frame_classification_model(
     annot_format,
     specific_config_toml_path,
     tmp_path,
-    device,
+    trainer_table,
     post_tfm_kwargs
 ):
     output_dir = tmp_path.joinpath(
@@ -56,7 +56,7 @@
     keys_to_change = [
         {"table": "eval", "key": "output_dir", "value": str(output_dir)},
-        {"table": "eval", "key": "device", "value": device},
+        {"table": "eval", "key": "trainer", "value": trainer_table},
     ]
     toml_path = specific_config_toml_path(
@@ -72,12 +72,12 @@
     vak.eval.frame_classification.eval_frame_classification_model(
         model_config=cfg.eval.model.asdict(),
         dataset_config=cfg.eval.dataset.asdict(),
+        trainer_config=cfg.eval.trainer.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         spect_scaler_path=cfg.eval.spect_scaler_path,
-        device=cfg.eval.device,
         post_tfm_kwargs=post_tfm_kwargs,
     )
@@ -96,7 +96,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
     path_option_to_change,
     specific_config_toml_path,
     tmp_path,
-    device
+    trainer_table
 ):
     """Test that core.eval raises FileNotFoundError
     when one of the following does not exist:
@@ -109,7 +109,7 @@
     keys_to_change = [
         {"table": "eval", "key": "output_dir", "value": str(output_dir)},
-        {"table": "eval", "key": "device", "value": device},
+        {"table": "eval", "key": "trainer", "value": trainer_table},
         path_option_to_change,
     ]
@@ -126,12 +12,12 @@
     vak.eval.frame_classification.eval_frame_classification_model(
         model_config=cfg.eval.model.asdict(),
         dataset_config=cfg.eval.dataset.asdict(),
+        trainer_config=cfg.eval.trainer.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         spect_scaler_path=cfg.eval.spect_scaler_path,
-        device=cfg.eval.device,
     )
@@ -145,7 +145,7 @@
 def test_eval_frame_classification_model_raises_not_a_directory(
     path_option_to_change,
     specific_config_toml_path,
-    device,
+    trainer_table,
     tmp_path,
 ):
     """Test that core.eval raises NotADirectory
@@ -153,7 +153,7 @@
     """
     keys_to_change = [
         path_option_to_change,
-        {"table": "eval", "key": "device", "value": device},
+        {"table": "eval", "key": "trainer", "value": trainer_table},
     ]
     if path_option_to_change["key"] != "output_dir":
@@ -180,10 +180,10 @@
     vak.eval.frame_classification.eval_frame_classification_model(
         model_config=cfg.eval.model.asdict(),
         dataset_config=cfg.eval.dataset.asdict(),
+        trainer_config=cfg.eval.trainer.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         spect_scaler_path=cfg.eval.spect_scaler_path,
-        device=cfg.eval.device,
     )
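Taken together, the tests above exercise the new plumbing end to end. The pattern they validate reduces to a short sketch that mirrors the changes to the predict and eval modules; the helper name is illustrative, not part of this changeset:

import lightning

def make_trainer(trainer_config: dict, output_dir) -> lightning.pytorch.Trainer:
    # trainer_config comes from TrainerConfig.asdict(),
    # e.g. {"accelerator": "gpu", "devices": [0]} or {"accelerator": "cpu", "devices": 1}
    logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=output_dir)
    return lightning.pytorch.Trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        logger=logger,
    )

diff --git a/tests/test_eval/test_parametric_umap.py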
b/tests/test_eval/test_parametric_umap.py index 4c6c7e573..b1744291d 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -25,7 +25,7 @@ def test_eval_parametric_umap_model( annot_format, specific_config_toml_path, tmp_path, - device, + trainer_table, ): output_dir = tmp_path.joinpath( f"test_eval_{audio_format}_{spect_format}_{annot_format}" @@ -34,7 +34,7 @@ def test_eval_parametric_umap_model( keys_to_change = [ {"table": "eval", "key": "output_dir", "value": str(output_dir)}, - {"table": "eval", "key": "device", "value": device}, + {"table": "eval", "key": "trainer", "value": trainer_table}, ] toml_path = specific_config_toml_path( @@ -50,11 +50,11 @@ def test_eval_parametric_umap_model( vak.eval.parametric_umap.eval_parametric_umap_model( model_config=cfg.eval.model.asdict(), dataset_config=cfg.eval.dataset.asdict(), + trainer_config=cfg.eval.trainer.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - device=cfg.eval.device, ) assert_eval_output_matches_expected(cfg.eval.model.name, output_dir) @@ -70,7 +70,7 @@ def test_eval_frame_classification_model_raises_file_not_found( path_option_to_change, specific_config_toml_path, tmp_path, - device + trainer_table ): """Test that :func:`vak.eval.parametric_umap.eval_parametric_umap_model` raises FileNotFoundError when expected""" @@ -81,7 +81,7 @@ def test_eval_frame_classification_model_raises_file_not_found( keys_to_change = [ {"table": "eval", "key": "output_dir", "value": str(output_dir)}, - {"table": "eval", "key": "device", "value": device}, + {"table": "eval", "key": "trainer", "value": trainer_table}, path_option_to_change, ] @@ -98,11 +98,11 @@ def test_eval_frame_classification_model_raises_file_not_found( vak.eval.parametric_umap.eval_parametric_umap_model( model_config=cfg.eval.model.asdict(), dataset_config=cfg.eval.dataset.asdict(), + trainer_config=cfg.eval.trainer.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - device=cfg.eval.device, ) @@ -116,14 +116,14 @@ def test_eval_frame_classification_model_raises_file_not_found( def test_eval_frame_classification_model_raises_not_a_directory( path_option_to_change, specific_config_toml_path, - device, + trainer_table, tmp_path, ): """Test that :func:`vak.eval.parametric_umap.eval_parametric_umap_model` raises NotADirectoryError when expected""" keys_to_change = [ path_option_to_change, - {"table": "eval", "key": "device", "value": device}, + {"table": "eval", "key": "trainer", "value": trainer_table}, ] if path_option_to_change["key"] != "output_dir": @@ -150,9 +150,9 @@ def test_eval_frame_classification_model_raises_not_a_directory( vak.eval.parametric_umap.eval_parametric_umap_model( model_config=cfg.eval.model.asdict(), dataset_config=cfg.eval.dataset.asdict(), + trainer_config=cfg.eval.trainer.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - device=cfg.eval.device, ) diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index c88f4ebf4..ddaa99f6a 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -51,8 +51,8 @@ def assert_learncurve_output_matches_expected(cfg, model_name, results_path): ] ) def 
test_learning_curve_for_frame_classification_model( - model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, device): - keys_to_change = {"table": "learncurve", "key": "device", "value": device} + model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, trainer_table): + keys_to_change = {"table": "learncurve", "key": "trainer", "value": trainer_table} toml_path = specific_config_toml_path( config_type="learncurve", @@ -69,6 +69,7 @@ def test_learning_curve_for_frame_classification_model( vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_config=cfg.learncurve.model.asdict(), dataset_config=cfg.learncurve.dataset.asdict(), + trainer_config=cfg.learncurve.trainer.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, @@ -79,7 +80,6 @@ def test_learning_curve_for_frame_classification_model( val_step=cfg.learncurve.val_step, ckpt_step=cfg.learncurve.ckpt_step, patience=cfg.learncurve.patience, - device=cfg.learncurve.device, ) assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model.name, results_path) @@ -94,13 +94,13 @@ def test_learning_curve_for_frame_classification_model( ) def test_learncurve_raises_not_a_directory(dir_option_to_change, specific_config_toml_path, - tmp_path, device): + tmp_path, trainer_table): """Test that core.learncurve.learning_curve raises NotADirectoryError when the following directories do not exist: results_path, previous_run_path, dataset_path """ keys_to_change = [ - {"table": "learncurve", "key": "device", "value": device}, + {"table": "learncurve", "key": "trainer", "value": trainer_table}, dir_option_to_change ] toml_path = specific_config_toml_path( @@ -118,6 +118,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_config=cfg.learncurve.model.asdict(), dataset_config=cfg.learncurve.dataset.asdict(), + trainer_config=cfg.learncurve.trainer.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, @@ -128,5 +129,4 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, val_step=cfg.learncurve.val_step, ckpt_step=cfg.learncurve.ckpt_step, patience=cfg.learncurve.patience, - device=cfg.learncurve.device, ) diff --git a/tests/test_models/test_tweetynet.py b/tests/test_models/test_tweetynet.py index 6c06c2029..ce219ea9a 100644 --- a/tests/test_models/test_tweetynet.py +++ b/tests/test_models/test_tweetynet.py @@ -1,7 +1,7 @@ import itertools import pytest -import pytorch_lightning as lightning +import lightning import vak @@ -37,7 +37,7 @@ def test_model_is_decorated(self): assert issubclass(vak.models.TweetyNet, vak.models.base.Model) assert issubclass(vak.models.TweetyNet, - lightning.LightningModule) + lightning.pytorch.LightningModule) @pytest.mark.parametrize( 'labelmap, input_shape', diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 5b8902f9d..f0507a016 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -30,7 +30,7 @@ def test_predict_with_frame_classification_model( save_net_outputs, specific_config_toml_path, tmp_path, - device, + trainer_table, ): output_dir = tmp_path.joinpath( f"test_predict_{audio_format}_{spect_format}_{annot_format}" @@ -39,7 +39,7 @@ def 
test_predict_with_frame_classification_model( keys_to_change = [ {"table": "predict", "key": "output_dir", "value": str(output_dir)}, - {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "trainer", "value": trainer_table}, {"table": "predict", "key": "save_net_outputs", "value": save_net_outputs}, ] toml_path = specific_config_toml_path( @@ -54,12 +54,12 @@ def test_predict_with_frame_classification_model( vak.predict.frame_classification.predict_with_frame_classification_model( model_config=cfg.predict.model.asdict(), dataset_config=cfg.predict.dataset.asdict(), + trainer_config=cfg.predict.trainer.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, - device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, output_dir=cfg.predict.output_dir, min_segment_dur=cfg.predict.min_segment_dur, @@ -98,7 +98,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( path_option_to_change, specific_config_toml_path, tmp_path, - device + trainer_table ): """Test that core.eval raises FileNotFoundError when `dataset_path` does not exist.""" @@ -109,7 +109,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( keys_to_change = [ {"table": "predict", "key": "output_dir", "value": str(output_dir)}, - {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "trainer", "value": trainer_table}, path_option_to_change, ] toml_path = specific_config_toml_path( @@ -125,12 +125,12 @@ def test_predict_with_frame_classification_model_raises_file_not_found( vak.predict.frame_classification.predict_with_frame_classification_model( model_config=cfg.predict.model.asdict(), dataset_config=cfg.predict.dataset.asdict(), + trainer_config=cfg.predict.trainer.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, - device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, output_dir=cfg.predict.output_dir, min_segment_dur=cfg.predict.min_segment_dur, @@ -149,7 +149,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( def test_predict_with_frame_classification_model_raises_not_a_directory( path_option_to_change, specific_config_toml_path, - device, + trainer_table, tmp_path, ): """Test that core.eval raises NotADirectory @@ -157,7 +157,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( """ keys_to_change = [ path_option_to_change, - {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "trainer", "value": trainer_table}, ] if path_option_to_change["key"] != "output_dir": @@ -184,12 +184,12 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( vak.predict.frame_classification.predict_with_frame_classification_model( model_config=cfg.predict.model.asdict(), dataset_config=cfg.predict.dataset.asdict(), + trainer_config=cfg.predict.trainer.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, - device=cfg.predict.device, 
annot_csv_filename=cfg.predict.annot_csv_filename, output_dir=cfg.predict.output_dir, min_segment_dur=cfg.predict.min_segment_dur, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 820613ed4..f7a7f0f7a 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -28,7 +28,7 @@ def test_predict( keys_to_change = [ {"table": "predict", "key": "output_dir", "value": str(output_dir)}, - {"table": "predict", "key": "device", "value": 'cpu'}, + {"table": "predict", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}}, ] toml_path = specific_config_toml_path( @@ -47,12 +47,12 @@ def test_predict( vak.predict.predict( model_config=cfg.predict.model.asdict(), dataset_config=cfg.predict.dataset.asdict(), + trainer_config=cfg.predict.trainer.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, - device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, output_dir=cfg.predict.output_dir, min_segment_dur=cfg.predict.min_segment_dur, diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index ef3105144..0030e01d0 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -18,7 +18,7 @@ ] ) def test_make_index_vectors_for_each_subsets( - model_name, audio_format, annot_format, input_type, specific_config_toml_path, device, tmp_path, + model_name, audio_format, annot_format, input_type, specific_config_toml_path, trainer_table, tmp_path, ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() @@ -130,7 +130,7 @@ def test_make_index_vectors_for_each_subsets( ] ) def test_make_subsets_from_dataset_df( - model_name, audio_format, annot_format, input_type, specific_config_toml_path, device, tmp_path, + model_name, audio_format, annot_format, input_type, specific_config_toml_path, trainer_table, tmp_path, ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index f4e50ef46..e9d06aa98 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -42,12 +42,12 @@ def assert_train_output_matches_expected(cfg: vak.config.config.Config, model_na ], ) def test_train_frame_classification_model( - model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device + model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() keys_to_change = [ - {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "trainer", "value": trainer_table}, {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( @@ -63,6 +63,7 @@ def test_train_frame_classification_model( vak.train.frame_classification.train_frame_classification_model( model_config=cfg.train.model.asdict(), dataset_config=cfg.train.dataset.asdict(), + trainer_config=cfg.train.trainer.asdict(), batch_size=cfg.train.batch_size, 
diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py
index f4e50ef46..e9d06aa98 100644
--- a/tests/test_train/test_frame_classification.py
+++ b/tests/test_train/test_frame_classification.py
@@ -42,12 +42,12 @@ def assert_train_output_matches_expected(cfg: vak.config.config.Config, model_na
     ],
 )
 def test_train_frame_classification_model(
-    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
 ):
     results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
     results_path.mkdir()
     keys_to_change = [
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
         {"table": "train", "key": "root_results_dir", "value": str(results_path)}
     ]
     toml_path = specific_config_toml_path(
@@ -63,6 +63,7 @@ def test_train_frame_classification_model(
     vak.train.frame_classification.train_frame_classification_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -74,7 +75,6 @@
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
         patience=cfg.train.patience,
-        device=cfg.train.device,
     )

     assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -89,12 +89,12 @@
     ],
 )
 def test_continue_training(
-    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, device
+    model_name, audio_format, spect_format, annot_format, specific_config_toml_path, tmp_path, trainer_table
 ):
     results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
     results_path.mkdir()
     keys_to_change = [
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
         {"table": "train", "key": "root_results_dir", "value": str(results_path)}
     ]
     toml_path = specific_config_toml_path(
@@ -110,6 +110,7 @@ def test_continue_training(
     vak.train.frame_classification.train_frame_classification_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -121,7 +122,6 @@
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
         patience=cfg.train.patience,
-        device=cfg.train.device,
     )

     assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -135,14 +135,14 @@
     ]
 )
 def test_train_raises_file_not_found(
-    path_option_to_change, specific_config_toml_path, tmp_path, device
+    path_option_to_change, specific_config_toml_path, tmp_path, trainer_table
 ):
     """Test that pre-conditions in `vak.train` raise FileNotFoundError
     when one of the following does not exist:
     checkpoint_path, dataset_path, spect_scaler_path
     """
     keys_to_change = [
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
         path_option_to_change
     ]
     toml_path = specific_config_toml_path(
@@ -161,6 +161,7 @@
     vak.train.frame_classification.train_frame_classification_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -172,7 +173,6 @@
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
         patience=cfg.train.patience,
-        device=cfg.train.device,
     )
@@ -184,14 +184,14 @@
     ]
 )
 def test_train_raises_not_a_directory(
-    path_option_to_change, specific_config_toml_path, device, tmp_path
+    path_option_to_change, specific_config_toml_path, trainer_table, tmp_path
 ):
     """Test that train_frame_classification_model raises NotADirectoryError
     when a directory does not exist
     """
     keys_to_change = [
         path_option_to_change,
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
     ]
     toml_path = specific_config_toml_path(
@@ -211,6 +211,7 @@
     vak.train.frame_classification.train_frame_classification_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -222,5 +223,4 @@
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
         patience=cfg.train.patience,
-        device=cfg.train.device,
     )
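A detail of these tables that is easy to miss: Lightning interprets an integer `devices` as a count of devices to use, while a list selects specific device indices. Both values below are valid for `lightning.pytorch.Trainer`:

# `devices` semantics in lightning.pytorch.Trainer:
# an int is a count of devices; a list picks specific device indices.
cpu_table = {"accelerator": "cpu", "devices": 1}    # one CPU process
gpu_table = {"accelerator": "gpu", "devices": [0]}  # exactly GPU 0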
diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py
index 509078c21..247009a8d 100644
--- a/tests/test_train/test_parametric_umap.py
+++ b/tests/test_train/test_parametric_umap.py
@@ -35,12 +35,12 @@ def assert_train_output_matches_expected(cfg: vak.config.config.Config, model_na
 )
 def test_train_parametric_umap_model(
     model_name, audio_format, spect_format, annot_format,
-    specific_config_toml_path, tmp_path, device
+    specific_config_toml_path, tmp_path, trainer_table
 ):
     results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
     results_path.mkdir()
     keys_to_change = [
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
         {"table": "train", "key": "root_results_dir", "value": str(results_path)}
     ]
     toml_path = specific_config_toml_path(
@@ -56,6 +56,7 @@ def test_train_parametric_umap_model(
     vak.train.parametric_umap.train_parametric_umap_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -64,7 +65,6 @@
         shuffle=cfg.train.shuffle,
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
-        device=cfg.train.device,
     )

     assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)
@@ -77,14 +77,14 @@
     ]
 )
 def test_train_parametric_umap_model_raises_file_not_found(
-    path_option_to_change, specific_config_toml_path, tmp_path, device
+    path_option_to_change, specific_config_toml_path, tmp_path, trainer_table
 ):
     """Test that pre-conditions in
     :func:`vak.train.parametric_umap.train_parametric_umap_model`
     raise FileNotFoundError when one of the following does not exist:
     checkpoint_path, dataset_path
     """
     keys_to_change = [
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
         path_option_to_change
     ]
     toml_path = specific_config_toml_path(
@@ -103,6 +103,7 @@
     vak.train.parametric_umap.train_parametric_umap_model(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -111,7 +112,6 @@
         shuffle=cfg.train.shuffle,
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
-        device=cfg.train.device,
     )
@@ -123,14 +123,14 @@
     ]
 )
 def test_train_parametric_umap_model_raises_not_a_directory(
-    path_option_to_change, specific_config_toml_path, device, tmp_path
+    path_option_to_change, specific_config_toml_path, trainer_table, tmp_path
 ):
     """Test that train_parametric_umap_model raises NotADirectoryError
     when a directory does not exist
     """
     keys_to_change = [
         path_option_to_change,
-        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "trainer", "value": trainer_table},
     ]
     toml_path = specific_config_toml_path(
@@ -151,6 +151,7 @@ def test_train_parametric_umap_model_raises_not_a_directory(
     vak.train.parametric_umap.train_parametric_umap_model(
         model_config=model_config,
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -159,5 +160,4 @@
         shuffle=cfg.train.shuffle,
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
-        device=cfg.train.device,
     )
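Downstream of all these signatures, the `trainer_config` dict is presumably what gets unpacked into `lightning.pytorch.Trainer`; that code is not part of this patch. A minimal sketch of the hand-off, where the helper name `get_trainer` and the pass-through of extra keyword arguments are assumptions:

from lightning.pytorch import Trainer


def get_trainer(trainer_config: dict, **extra_kwargs) -> Trainer:
    # Hypothetical helper: build a Trainer from a [vak.<command>.trainer] table,
    # e.g. the dict returned by cfg.train.trainer.asdict().
    return Trainer(**trainer_config, **extra_kwargs)


trainer = get_trainer({"accelerator": "cpu", "devices": 1}, max_epochs=2)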
diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py
index b9038007e..11886db68 100644
--- a/tests/test_train/test_train.py
+++ b/tests/test_train/test_train.py
@@ -35,7 +35,7 @@ def test_train(
             "key": "root_results_dir",
             "value": str(root_results_dir),
         },
-        {"table": "train", "key": "device", "value": 'cpu'},
+        {"table": "train", "key": "trainer", "value": {"accelerator": "cpu", "devices": 1}},
     ]
     toml_path = specific_config_toml_path(
@@ -55,6 +55,7 @@ def test_train(
     vak.train.train(
         model_config=cfg.train.model.asdict(),
         dataset_config=cfg.train.dataset.asdict(),
+        trainer_config=cfg.train.trainer.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
@@ -66,6 +67,5 @@
         val_step=cfg.train.val_step,
         ckpt_step=cfg.train.ckpt_step,
         patience=cfg.train.patience,
-        device=cfg.train.device,
     )
     assert mock_train_function.called
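As a final sanity check of the new config shape: the exact table that `test_train.py` injects constructs a usable Trainer on any machine, since a CPU is always available (`fast_dev_run` keeps Lightning from doing real work if this Trainer were ever handed a model):

from lightning.pytorch import Trainer

trainer_table = {"accelerator": "cpu", "devices": 1}  # the table injected in test_train.py
trainer = Trainer(**trainer_table, fast_dev_run=True)
assert trainer.num_devices == 1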