Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add lightning.Trainer config, fix #691 #687 #742 #745 #752

Merged
merged 47 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
b7f43a9
WIP: Add config/trainer.py with TrainerConfig
NickleDave May 5, 2024
54b8718
Rename common.device -> common.accelerator, return 'gpu' not 'cuda' i…
NickleDave May 5, 2024
86b40c2
Fix config section in doc/api/index.rst
NickleDave May 5, 2024
cf8f79c
Import trainer and TrainerConfig in src/vak/config/__init__.py, add t…
NickleDave May 5, 2024
5b35753
Add pytorch-lightning to intersphinx in doc/conf.py
NickleDave May 5, 2024
8549cb4
Fix cross-ref in docstring in src/vak/prep/frame_classification/make_…
NickleDave May 5, 2024
23b78a3
Make lightning a dependency, instead of pytorch_lightning; import lig…
NickleDave May 5, 2024
9d06234
Fix in doc/api/index.rst: common.device -> common.accelerator
NickleDave May 5, 2024
ad0cf3e
Finish writing TrainerConfig class
NickleDave May 5, 2024
9fd60d2
Add tests for TrainerConfig class
NickleDave May 5, 2024
64ee588
Add trainer sub-table to all configs in tests/data_for_tests/configs
NickleDave May 5, 2024
68efbc0
Add trainer sub-table to all configs in doc/toml
NickleDave May 5, 2024
a6337f5
Add trainer sub-table in config/valid-version-1.0.toml, rename -> val…
NickleDave May 5, 2024
7094a6d
Remove device key from top-level tables in config/valid-version-1.1.toml
NickleDave May 5, 2024
c14348e
Remove device key from top-level tables in tests/data_for_tests/configs
NickleDave May 5, 2024
5d378e9
Remove 'device' key from configs in doc/toml
NickleDave May 5, 2024
cfd1972
Add 'trainer' attribute to EvalConfig, an instance of TrainerConfig; …
NickleDave May 5, 2024
96df71e
Add 'trainer' attribute to PredictConfig, an instance of TrainerConfi…
NickleDave May 5, 2024
7c10814
Add 'trainer' attribute to TrainConfig, an instance of TrainerConfig;…
NickleDave May 5, 2024
fd12f97
Fix typo in docstring in src/vak/config/train.py
NickleDave May 5, 2024
7253c07
Add 'trainer' attribute to LearncurveConfig, an instance of TrainerCo…
NickleDave May 5, 2024
ef05f6f
Remove device attribute from TrainConfig docstring
NickleDave May 5, 2024
81b9a7d
Fix VALID_TOML_PATH in config/validators.py -> 'valid-version-1.1.toml'
NickleDave May 5, 2024
5cd078d
Fix how we instantiate TrainerConfig classes in from_config_dict meth…
NickleDave May 5, 2024
436623a
Fix typo in src/vak/config/valid-version-1.1.toml: predictor -> predict
NickleDave May 5, 2024
bd5f7b0
Fix unit tests after adding trainer attribute that is instance of Tra…
NickleDave May 5, 2024
2930eee
Change src/vak/train/frame_classification.py to take trainer_config a…
NickleDave May 5, 2024
13a30ec
Change src/vak/train/parametric_umap.py to take trainer_config argument
NickleDave May 5, 2024
07031e1
Change src/vak/train/train_.py to take trainer_config argument
NickleDave May 5, 2024
f9400df
Fix src/vak/cli/train.py to pass trainer_config.asdict() into vak.tra…
NickleDave May 5, 2024
5ac40e0
Replace 'device' with 'trainer_config' in vak/eval
NickleDave May 5, 2024
46ed723
Fix cli.eval to pass trainer_config into eval.eval_.eval
NickleDave May 5, 2024
994b69e
Replace 'device' with 'trainer_config' in vak/predict
NickleDave May 5, 2024
9a09ba1
Fix cli.predict to pass trainer_config into predict.predict_.predict
NickleDave May 5, 2024
a50eb34
Replace 'device' with 'trainer_config' in vak/learncurve
NickleDave May 5, 2024
64e9331
Fix cli.learncurve to pass trainer_config into learncurve.learncurve.…
NickleDave May 5, 2024
22e9faa
Rename/replace 'device' fixture with 'trainer' fixture in tests/
NickleDave May 5, 2024
9a157f2
Use config.table.trainer attribute throughout tests, remove config.ta…
NickleDave May 5, 2024
8e96836
Fix value for devices in fixtures/trainer.py: when device is 'cpu' tr…
NickleDave May 5, 2024
134dc0c
Fix default devices value for when accelerator is cpu in TrainerConfig
NickleDave May 5, 2024
3c4ee20
Fix unit tests for TrainerConfig after fixing default devices for acc…
NickleDave May 5, 2024
d92af7a
Fix default value for 'devices' set to -1 in some unit tests where we…
NickleDave May 5, 2024
ab12771
fixup use config.table.trainer attribute throughout tests -- missed o…
NickleDave May 5, 2024
51b888d
Add back 'device' fixture so we can use it to test Model class
NickleDave May 5, 2024
5a243a9
Fix unit tests in test_models/test_base.by that literally used device…
NickleDave May 5, 2024
151abe2
Fix assertion in tests/test_models/test_tweetynet.py, from where we s…
NickleDave May 5, 2024
316e7a3
Fix test for DiceLoss, change trainer_type fixture back to device fix…
NickleDave May 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,16 @@ and dataclasses that represent tables from those files.
:recursive:

config.config
config.dataset
config.eval
config.learncurve
config.load
config.model
config.parse
config.predict
config.prep
config.spect_params
config.train
config.trainer
config.validators

Datasets
Expand Down Expand Up @@ -265,10 +267,10 @@ used by multiple other modules.
:template: module.rst
:recursive:

common.accelerator
common.annotation
common.constants
common.converters
common.device
common.files
common.labels
common.learncurve
Expand Down
7 changes: 4 additions & 3 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
"announcement":
""" 🚧 vak version 1.0.0 is in development! 🚧
📣 Test out the alpha release: <code>pip install vak==1.0.0a3</code>. 📣
For more info, please see
<a href="https://forum.vocalpy.org/t/vak-1-0-0a1-released/55"> this forum post<a>.
For more info, please see
<a href="https://forum.vocalpy.org/t/vak-1-0-0a1-released/55"> this forum post<a>.
""",
"sidebar_hide_name": True,
"light_css_variables": {
Expand Down Expand Up @@ -217,7 +217,8 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
"pytorch": ("https://pytorch.org/docs/stable/", None)
"pytorch": ("https://pytorch.org/docs/stable/", None),
"lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
}

# -- Options for todo extension ----------------------------------------------
Expand Down
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_eval.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ batch_size = 11
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 16
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/eval"
# dataset_path : path to dataset created by prep
Expand Down Expand Up @@ -72,3 +72,10 @@ window_size = 176
# Note we do not specify any options for the model, and just use the defaults
# We need to put this table here though so we know which model we are using
[vak.eval.model.TweetyNet]

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.eval.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_predict.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ batch_size = 1
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# output_dir: directory where output should be saved, as a sub-directory within `output_dir`
output_dir = "/PATH/TO/FOLDER/results/predict"
# annot_csv_filename
Expand Down Expand Up @@ -67,3 +67,10 @@ window_size = 176
# Note we do not specify any options for the network, and just use the defaults
# We need to put this table here though, to indicate which model we are using.
[vak.predict.model.TweetyNet]

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.predict.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
9 changes: 8 additions & 1 deletion doc/toml/gy6or6_train.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ patience = 4
# num_workers: number of workers to use when loading data with multiprocessing
num_workers = 4
# device: name of device to run model on, one of "cuda", "cpu"
device = "cuda"

# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it

# dataset.params = parameters used for datasets
Expand All @@ -73,3 +73,10 @@ lr = 0.001
[vak.train.model.TweetyNet.network]
# hidden_size: the number of elements in the hidden state in the recurrent layer of the network
hidden_size = 256

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.train.trainer]
# setting to 'gpu' means "train models on 'gpu' (not 'cpu')"
accelerator = "gpu"
# use the first GPU (numbering starts from 0)
devices = [0]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"dask[dataframe] >=2.10.1",
"evfuncs >=0.3.4",
"joblib >=0.14.1",
"pytorch-lightning >=2.0.7",
"lightning >=2.0.7",
"matplotlib >=3.3.3",
"numpy >=1.18.1",
"pynndescent >=0.5.10",
Expand Down
2 changes: 1 addition & 1 deletion src/vak/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def eval(toml_path: str | pathlib.Path) -> None:
eval_module.eval(
model_config=cfg.eval.model.asdict(),
dataset_config=cfg.eval.dataset.asdict(),
trainer_config=cfg.eval.trainer.asdict(),
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
device=cfg.eval.device,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
2 changes: 1 addition & 1 deletion src/vak/cli/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def learning_curve(toml_path):
learncurve.learning_curve(
model_config=cfg.learncurve.model.asdict(),
dataset_config=cfg.learncurve.dataset.asdict(),
trainer_config=cfg.learncurve.trainer.asdict(),
batch_size=cfg.learncurve.batch_size,
num_epochs=cfg.learncurve.num_epochs,
num_workers=cfg.learncurve.num_workers,
Expand All @@ -65,5 +66,4 @@ def learning_curve(toml_path):
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
patience=cfg.learncurve.patience,
device=cfg.learncurve.device,
)
2 changes: 1 addition & 1 deletion src/vak/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def predict(toml_path):
predict_module.predict(
model_config=cfg.predict.model.asdict(),
dataset_config=cfg.predict.dataset.asdict(),
trainer_config=cfg.predict.trainer.asdict(),
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
device=cfg.predict.device,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
Expand Down
2 changes: 1 addition & 1 deletion src/vak/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def train(toml_path):
train_module.train(
model_config=cfg.train.model.asdict(),
dataset_config=cfg.train.dataset.asdict(),
trainer_config=cfg.train.trainer.asdict(),
batch_size=cfg.train.batch_size,
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
Expand All @@ -66,5 +67,4 @@ def train(toml_path):
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
patience=cfg.train.patience,
device=cfg.train.device,
)
4 changes: 2 additions & 2 deletions src/vak/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"""

from . import (
accelerator,
annotation,
constants,
converters,
device,
files,
labels,
learncurve,
Expand All @@ -30,7 +30,7 @@
"annotation",
"constants",
"converters",
"device",
"accelerator",
"files",
"labels",
"learncurve",
Expand Down
16 changes: 16 additions & 0 deletions src/vak/common/accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch


def get_default() -> str:
"""Get default `accelerator` for :class:`lightning.pytorch.Trainer`.

Returns
-------
accelerator : str
Will be ``'gpu'`` if :func:`torch.cuda.is_available`
is ``True``, and ``'cpu'`` if not.
"""
if torch.cuda.is_available():
return "gpu"
else:
return "cpu"
16 changes: 0 additions & 16 deletions src/vak/common/device.py

This file was deleted.

48 changes: 31 additions & 17 deletions src/vak/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import pathlib

import pytorch_lightning as lightning
import lightning


def get_default_train_callbacks(
ckpt_root: str | pathlib.Path,
ckpt_step: int,
patience: int,
):
ckpt_callback = lightning.callbacks.ModelCheckpoint(
ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
dirpath=ckpt_root,
filename="checkpoint",
every_n_train_steps=ckpt_step,
Expand All @@ -20,7 +20,7 @@ def get_default_train_callbacks(
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
ckpt_callback.FILE_EXTENSION = ".pt"

val_ckpt_callback = lightning.callbacks.ModelCheckpoint(
val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
monitor="val_acc",
dirpath=ckpt_root,
save_top_k=1,
Expand All @@ -31,7 +31,7 @@ def get_default_train_callbacks(
)
val_ckpt_callback.FILE_EXTENSION = ".pt"

early_stopping = lightning.callbacks.EarlyStopping(
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
mode="max",
monitor="val_acc",
patience=patience,
Expand All @@ -42,33 +42,47 @@ def get_default_train_callbacks(


def get_default_trainer(
accelerator: str,
devices: int | list[int],
max_steps: int,
log_save_dir: str | pathlib.Path,
val_step: int,
default_callback_kwargs: dict | None = None,
device: str = "cuda",
) -> lightning.Trainer:
"""Returns an instance of ``lightning.Trainer``
) -> lightning.pytorch.Trainer:
"""Returns an instance of :class:`lightning.pytorch.Trainer`
with a default set of callbacks.
Used by ``vak.core`` functions."""

Used by :func:`vak.train.frame_classification`.
The default set of callbacks is provided by
:func:`get_default_train_callbacks`.

Parameters
----------
accelerator : str
devices : int, list of int
max_steps : int
log_save_dir : str, pathlib.Path
val_step : int
default_callback_kwargs : dict, optional

Returns
-------
trainer : lightning.pytorch.Trainer

"""
if default_callback_kwargs:
callbacks = get_default_train_callbacks(**default_callback_kwargs)
else:
callbacks = None

# TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
if device == "cuda":
accelerator = "gpu"
else:
accelerator = "auto"

logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)
logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)

trainer = lightning.Trainer(
trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
devices=devices,
callbacks=callbacks,
val_check_interval=val_step,
max_steps=max_steps,
accelerator=accelerator,
logger=logger,
)
return trainer
5 changes: 5 additions & 0 deletions src/vak/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
prep,
spect_params,
train,
trainer,
validators,
)
from .config import Config
Expand All @@ -22,6 +23,8 @@
from .prep import PrepConfig
from .spect_params import SpectParamsConfig
from .train import TrainConfig
from .trainer import TrainerConfig


__all__ = [
"config",
Expand All @@ -34,6 +37,7 @@
"prep",
"spect_params",
"train",
"trainer",
"validators",
"Config",
"DatasetConfig",
Expand All @@ -44,4 +48,5 @@
"PrepConfig",
"SpectParamsConfig",
"TrainConfig",
"TrainerConfig",
]
Loading
Loading