WIP - introducing tensorboard to fibad #119

Draft · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion docs/notebooks.rst
@@ -3,4 +3,4 @@ Notebooks

 .. toctree::

-   Introducing Jupyter Notebooks <notebooks/intro_notebook>
+   Training a simple model <notebooks/train_model>
66 changes: 0 additions & 66 deletions docs/notebooks/TrainingAModel.ipynb

This file was deleted.

84 changes: 0 additions & 84 deletions docs/notebooks/intro_notebook.ipynb

This file was deleted.

97 changes: 97 additions & 0 deletions docs/notebooks/train_model.ipynb
@@ -0,0 +1,97 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intro to Training and Configurations\n",
"\n",
"First we import fibad and create a new fibad object, instantiated (implicitly), with the default configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import fibad\n",
"\n",
"fibad_instance = fibad.Fibad()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this demo, we'll make a few adjustments to the default configuration settings that the `fibad` object was instantiated with. By accessing the `.config` attribute of the fibad instance, we can modify any configuration value. Here we change which built in model to use, the dataset, batch size, number of epochs for training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fibad_instance.config[\"model\"][\"name\"] = \"ExampleCNN\"\n",
"fibad_instance.config[\"data_set\"][\"name\"] = \"CifarDataSet\"\n",
"fibad_instance.config[\"data_loader\"][\"batch_size\"] = 64\n",
"fibad_instance.config[\"train\"][\"epochs\"] = 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We call the `.train()` method to train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fibad_instance.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output of the training will be stored in a time-stamped directory under the `./results/`. By default, a copy of the final configuration used in training is persisted as `runtime_config.toml`. To run fibad again with the same configuration, you can reference the runtime_config.toml file.\n",
"\n",
"If running in another notebook, instantiate a fibad object like so:\n",
"```\n",
"new_fibad_instance = fibad.Fibad(config_file='./results/<timestamped_directory>/runtime_config.toml')\n",
"```\n",
"\n",
"Or from the command line:\n",
"```\n",
">> fibad train --runtime-config ./results/<timestamped_directory>/runtime_config.toml\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fibad",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
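For reference, the same overrides could also live in a standalone TOML config of the kind `runtime_config.toml` persists. This is a sketch only: the section and key names are taken from the `.config` accesses in the notebook above, and the exact file schema is otherwise an assumption.

```toml
# Hypothetical config file; sections mirror the notebook's .config accesses
[model]
name = "ExampleCNN"

[data_set]
name = "CifarDataSet"

[data_loader]
batch_size = 64

[train]
epochs = 2
```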
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -21,6 +21,8 @@ dependencies = [
"toml", # Used to load configuration files as dictionaries
"torch", # Used for CNN model and in train.py
"torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set
"tensorboardX", # Used to log training metrics
"tensorboard", # Used to log training metrics
]

[project.scripts]
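As context for the two new dependencies: ignite's `TensorboardLogger` (wired up below in `pytorch_ignite.py`) needs a `SummaryWriter` backend, which `tensorboardX` provides, while the `tensorboard` package supplies the viewer. A minimal standalone sketch of the underlying writer API, independent of fibad:

```python
from tensorboardX import SummaryWriter

# Write one scalar data point that TensorBoard can plot; the log
# directory here is arbitrary and only for illustration.
writer = SummaryWriter(log_dir="./results/demo")
writer.add_scalar("training/loss", 0.42, global_step=1)
writer.close()
```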
21 changes: 21 additions & 0 deletions src/fibad/pytorch_ignite.py
@@ -7,6 +7,7 @@
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.handlers.tensorboard_logger import GradsScalarHandler, TensorboardLogger, WeightsHistHandler
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import Dataset

@@ -214,10 +215,30 @@
greater_or_equal=True,
)

tensorboard_logger = TensorboardLogger(log_dir=results_directory)

if config["train"]["resume"]:
prev_checkpoint = torch.load(config["train"]["resume"], map_location=device)
Checkpoint.load_objects(to_load=to_save, checkpoint=prev_checkpoint)

tensorboard_logger.attach(
trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
)

tensorboard_logger.attach(
trainer,
log_handler=WeightsHistHandler(model),
event_name=Events.ITERATION_COMPLETED(every=100),
)

tensorboard_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED(every=10),
tag="training",
output_transform=lambda loss: loss,
metric_names="all",
)

@trainer.on(Events.STARTED)
def log_training_start(trainer):
logger.info(f"Training model on device: {device}")
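Not part of this diff, but a usage note on what the attached handlers produce: gradient norms and weight histograms are logged every 100 iterations, and the training loss every 10, all written into the run's results directory. Assuming the time-stamped layout described in the notebook above, the logs can then be viewed with:

```
tensorboard --logdir ./results/<timestamped_directory>
```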