ENH: Add lightning.Trainer config, fix #691 #687 #742 #745 (#752)
* WIP: Add config/trainer.py with TrainerConfig

* Rename common.device -> common.accelerator; return 'gpu', not 'cuda', when torch.cuda.is_available() is True

* Fix config section in doc/api/index.rst

* Import trainer and TrainerConfig in src/vak/config/__init__.py, add to __all__

* Add pytorch-lightning to intersphinx in doc/conf.py

* Fix cross-ref in docstring in src/vak/prep/frame_classification/make_splits.py: :constant: -> :const:

* Make lightning the dependency instead of pytorch_lightning; import lightning.pytorch everywhere instead of pytorch_lightning as lightning, so that API cross-references in docstrings resolve correctly

* Fix in doc/api/index.rst: common.device -> common.accelerator

* Finish writing TrainerConfig class

* Add tests for TrainerConfig class

* Add trainer sub-table to all configs in tests/data_for_tests/configs

* Add trainer sub-table to all configs in doc/toml

* Add trainer sub-table in config/valid-version-1.0.toml, rename -> valid-version-1.1.toml

* Remove device key from top-level tables in config/valid-version-1.1.toml

* Remove device key from top-level tables in tests/data_for_tests/configs

* Remove 'device' key from configs in doc/toml

* Add 'trainer' attribute to EvalConfig, an instance of TrainerConfig; remove 'device' attribute

* Add 'trainer' attribute to PredictConfig, an instance of TrainerConfig; remove 'device' attribute

* Add 'trainer' attribute to TrainConfig, an instance of TrainerConfig; remove 'device' attribute

* Fix typo in docstring in src/vak/config/train.py

* Add 'trainer' attribute to LearncurveConfig, an instance of TrainerConfig; remove 'device' attribute. Also clean up docstring, removing attributes that no longer exist

* Remove device attribute from TrainConfig docstring

* Fix VALID_TOML_PATH in config/validators.py -> 'valid-version-1.1.toml'

* Fix how we instantiate TrainerConfig classes in from_config_dict method of EvalConfig/LearncurveConfig/PredictConfig/TrainConfig

* Fix typo in src/vak/config/valid-version-1.1.toml: predictor -> predict

* Fix unit tests after adding trainer attribute that is instance of TrainerConfig

* Change src/vak/train/frame_classification.py to take trainer_config argument

* Change src/vak/train/parametric_umap.py to take trainer_config argument

* Change src/vak/train/train_.py to take trainer_config argument

* Fix src/vak/cli/train.py to pass trainer_config.asdict() into vak.train.train_.train

* Replace 'device' with 'trainer_config' in vak/eval

* Fix cli.eval to pass trainer_config into eval.eval_.eval

* Replace 'device' with 'trainer_config' in vak/predict

* Fix cli.predict to pass trainer_config into predict.predict_.predict

* Replace 'device' with 'trainer_config' in vak/learncurve

* Fix cli.learncurve to pass trainer_config into learncurve.learncurve.learning_curve

* Rename/replace 'device' fixture with 'trainer' fixture in tests/

* Use config.table.trainer attribute throughout tests, remove config.table.device attribute that no longer exists

* Fix value for devices in fixtures/trainer.py: when accelerator is 'cpu', devices must be an int > 0

* Fix default devices value for when accelerator is cpu in TrainerConfig

* Fix unit tests for TrainerConfig after fixing default devices for accelerator=cpu

* Fix default value for 'devices' set to -1 in some unit tests where we override the config in the toml file

* fixup: use config.table.trainer attribute throughout tests -- missed one place in tests/test_eval/

* Add back 'device' fixture so we can use it to test Model class

* Fix unit tests in test_models/test_base.py that actually used the device to put tensors on it, rather than to change a config

* Fix assertion in tests/test_models/test_tweetynet.py, needed after we switched to lightning as the dependency

* Fix test for DiceLoss, change trainer_type fixture back to device fixture
NickleDave authored May 5, 2024
1 parent 1e6abe6 commit e74b92b
Showing 71 changed files with 772 additions and 375 deletions.
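Before the diff: the definition of the new TrainerConfig is not among the files shown below, so here is a minimal hypothetical sketch, pieced together from how it is used in this commit (the accelerator/devices keys in the toml sub-tables, the cfg.<table>.trainer.asdict() calls, and the cpu-devices fix in the commit message). Names and defaults are assumptions, not the actual src/vak/config/trainer.py:

from __future__ import annotations

import attrs


@attrs.define
class TrainerConfig:
    """Sketch of the class representing a [vak.<command>.trainer] sub-table;
    its attributes map onto lightning.pytorch.Trainer arguments."""

    # the real default may be common.accelerator.get_default()
    accelerator: str = "gpu"
    devices: int | list[int] = 1

    def __attrs_post_init__(self):
        # per the commit message: when accelerator is 'cpu',
        # devices must be an int > 0, not a list of GPU indices
        if self.accelerator == "cpu" and (
            isinstance(self.devices, list) or self.devices < 1
        ):
            raise ValueError("when accelerator is 'cpu', devices must be an int > 0")

    def asdict(self) -> dict:
        # what gets passed as `trainer_config` into
        # vak.train / vak.eval / vak.predict / vak.learncurve
        return attrs.asdict(self)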
6 changes: 4 additions & 2 deletions doc/api/index.rst
@@ -130,14 +130,16 @@ and dataclasses that represent tables from those files.
:recursive:

config.config
config.dataset
config.eval
config.learncurve
config.load
config.model
config.parse
config.predict
config.prep
config.spect_params
config.train
config.trainer
config.validators

Datasets
@@ -265,10 +267,10 @@ used by multiple other modules.
:template: module.rst
:recursive:

common.accelerator
common.annotation
common.constants
common.converters
common.device
common.files
common.labels
common.learncurve
7 changes: 4 additions & 3 deletions doc/conf.py
@@ -107,8 +107,8 @@
"announcement":
""" 🚧 vak version 1.0.0 is in development! 🚧
📣 Test out the alpha release: <code>pip install vak==1.0.0a3</code>. 📣
For more info, please see
<a href="https://forum.vocalpy.org/t/vak-1-0-0a1-released/55"> this forum post<a>.
For more info, please see
<a href="https://forum.vocalpy.org/t/vak-1-0-0a1-released/55"> this forum post<a>.
""",
"sidebar_hide_name": True,
"light_css_variables": {
@@ -217,7 +217,8 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
"pytorch": ("https://pytorch.org/docs/stable/", None)
"pytorch": ("https://pytorch.org/docs/stable/", None),
"lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
}

# -- Options for todo extension ----------------------------------------------
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_eval.toml
@@ -43,7 +43,7 @@ batch_size = 11
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 16
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/eval"
# dataset_path : path to dataset created by prep
@@ -72,3 +72,10 @@ window_size = 176
# Note we do not specify any options for the model, and just use the defaults
# We need to put this table here though so we know which model we are using
[vak.eval.model.TweetyNet]

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.eval.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
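For context, the two keys in this new sub-table (and the identical ones added to the predict and train configs below) pass straight through to Lightning. A sketch, assuming only that vak forwards them unchanged:

import lightning

# values as parsed from the [vak.eval.trainer] sub-table above
trainer_config = {"accelerator": "gpu", "devices": [0]}
trainer = lightning.pytorch.Trainer(**trainer_config)  # runs on GPU 0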
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_predict.toml
@@ -39,7 +39,7 @@ batch_size = 1
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/predict"
# annot_csv_filename
@@ -67,3 +67,10 @@ window_size = 176
# Note we do not specify any options for the network, and just use the defaults
# We need to put this table here though, to indicate which model we are using.
[vak.predict.model.TweetyNet]

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.predict.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_train.toml
@@ -53,7 +53,7 @@ patience = 4
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it

# dataset.params = parameters used for datasets
@@ -73,3 +73,10 @@ lr = 0.001
[vak.train.model.TweetyNet.network]
# hidden_size: the number of elements in the hidden state in the recurrent layer of the network
hidden_size = 256

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.train.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"dask[dataframe] >=2.10.1",
"evfuncs >=0.3.4",
"joblib >=0.14.1",
"pytorch-lightning >=2.0.7",
"lightning >=2.0.7",
"matplotlib >=3.3.3",
"numpy >=1.18.1",
"pynndescent >=0.5.10",
2 changes: 1 addition & 1 deletion src/vak/cli/eval.py
@@ -52,12 +52,12 @@ def eval(toml_path: str | pathlib.Path) -> None:
eval_module.eval(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
device=cfg.eval.device,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
2 changes: 1 addition & 1 deletion src/vak/cli/learncurve.py
@@ -55,6 +55,7 @@ def learning_curve(toml_path):
learncurve.learning_curve(
model_config=cfg.learncurve.model.asdict(),
dataset_config=cfg.learncurve.dataset.asdict(),
trainer_config=cfg.learncurve.trainer.asdict(),
batch_size=cfg.learncurve.batch_size,
num_epochs=cfg.learncurve.num_epochs,
num_workers=cfg.learncurve.num_workers,
@@ -65,5 +66,4 @@
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
patience=cfg.learncurve.patience,
device=cfg.learncurve.device,
)
2 changes: 1 addition & 1 deletion src/vak/cli/predict.py
@@ -45,12 +45,12 @@ def predict(toml_path):
predict_module.predict(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
2 changes: 1 addition & 1 deletion src/vak/cli/train.py
@@ -55,6 +55,7 @@ def train(toml_path):
train_module.train(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
@@ -66,5 +67,4 @@
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
device=cfg.train.device,
)
4 changes: 2 additions & 2 deletions src/vak/common/__init__.py
@@ -9,10 +9,10 @@
"""

from . import (
accelerator,
annotation,
constants,
converters,
device,
files,
labels,
learncurve,
@@ -30,7 +30,7 @@
"annotation",
"constants",
"converters",
"device",
"accelerator",
"files",
"labels",
"learncurve",
16 changes: 16 additions & 0 deletions src/vak/common/accelerator.py
@@ -0,0 +1,16 @@
import torch


def get_default() -> str:
"""Get default `accelerator` for :class:`lightning.pytorch.Trainer`.
Returns
-------
accelerator : str
Will be ``'gpu'`` if :func:`torch.cuda.is_available`
is ``True``, and ``'cpu'`` if not.
"""
if torch.cuda.is_available():
return "gpu"
else:
return "cpu"
16 changes: 0 additions & 16 deletions src/vak/common/device.py

This file was deleted.

48 changes: 31 additions & 17 deletions src/vak/common/trainer.py
@@ -2,15 +2,15 @@

import pathlib

import pytorch_lightning as lightning
import lightning


def get_default_train_callbacks(
ckpt_root: str | pathlib.Path,
ckpt_step: int,
patience: int,
):
ckpt_callback = lightning.callbacks.ModelCheckpoint(
ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
dirpath=ckpt_root,
filename="checkpoint",
every_n_train_steps=ckpt_step,
@@ -20,7 +20,7 @@ def get_default_train_callbacks(
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
ckpt_callback.FILE_EXTENSION = ".pt"

val_ckpt_callback = lightning.callbacks.ModelCheckpoint(
val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
monitor="val_acc",
dirpath=ckpt_root,
save_top_k=1,
@@ -31,7 +31,7 @@ def get_default_train_callbacks(
)
val_ckpt_callback.FILE_EXTENSION = ".pt"

early_stopping = lightning.callbacks.EarlyStopping(
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
mode="max",
monitor="val_acc",
patience=patience,
@@ -42,33 +42,47 @@


def get_default_trainer(
accelerator: str,
devices: int | list[int],
max_steps: int,
log_save_dir: str | pathlib.Path,
val_step: int,
default_callback_kwargs: dict | None = None,
device: str = "cuda",
) -> lightning.Trainer:
"""Returns an instance of ``lightning.Trainer``
) -> lightning.pytorch.Trainer:
"""Returns an instance of :class:`lightning.pytorch.Trainer`
with a default set of callbacks.
Used by ``vak.core`` functions."""
Used by :func:`vak.train.frame_classification`.
The default set of callbacks is provided by
:func:`get_default_train_callbacks`.
Parameters
----------
accelerator : str
devices : int, list of int
max_steps : int
log_save_dir : str, pathlib.Path
val_step : int
default_callback_kwargs : dict, optional
Returns
-------
trainer : lightning.pytorch.Trainer
"""
if default_callback_kwargs:
callbacks = get_default_train_callbacks(**default_callback_kwargs)
else:
callbacks = None

# TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
if device == "cuda":
accelerator = "gpu"
else:
accelerator = "auto"

logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)
logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)

trainer = lightning.Trainer(
trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
devices=devices,
callbacks=callbacks,
val_check_interval=val_step,
max_steps=max_steps,
accelerator=accelerator,
logger=logger,
)
return trainer
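To make the new signature concrete, a sketch of a call site with placeholder values; the keyword names come from this diff, the values and surrounding context are assumptions:

# dict produced by TrainerConfig.asdict(), as passed down from the cli functions
trainer_config = {"accelerator": "gpu", "devices": [0]}
trainer = get_default_trainer(
    accelerator=trainer_config["accelerator"],
    devices=trainer_config["devices"],
    max_steps=2000,  # placeholder
    log_save_dir="results/train/logs",  # placeholder
    val_step=50,  # placeholder
    default_callback_kwargs={
        # keys match get_default_train_callbacks above
        "ckpt_root": "results/train/checkpoints",
        "ckpt_step": 100,
        "patience": 4,
    },
)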
5 changes: 5 additions & 0 deletions src/vak/config/__init__.py
@@ -11,6 +11,7 @@
prep,
spect_params,
train,
trainer,
validators,
)
from .config import Config
@@ -22,6 +23,8 @@
from .prep import PrepConfig
from .spect_params import SpectParamsConfig
from .train import TrainConfig
from .trainer import TrainerConfig


__all__ = [
"config",
@@ -34,6 +37,7 @@
"prep",
"spect_params",
"train",
"trainer",
"validators",
"Config",
"DatasetConfig",
@@ -44,4 +48,5 @@
"PrepConfig",
"SpectParamsConfig",
"TrainConfig",
"TrainerConfig",
]
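After this change the new class is importable from the top-level config package:

from vak.config import TrainerConfig  # equivalently: vak.config.trainer.TrainerConfig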