Commit

WIP - Create checkpoint at end of each epoch.
drewoldag committed Oct 9, 2024
1 parent 6b8d978 · commit 8e9359e
Showing 1 changed file with 32 additions and 0 deletions.
src/fibad/pytorch_ignite.py (32 additions, 0 deletions)
@@ -5,6 +5,7 @@
 import ignite.distributed as idist
 import torch
 from ignite.engine import Engine, Events
+from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 from torch.utils.data import Dataset

@@ -170,6 +171,34 @@ def create_trainer(model) -> Engine:
     model = idist.auto_model(model)
     trainer = create_engine("train_step", device, model)

+    to_save = {
+        'model': model,
+        'optimizer': model.optimizer,
+        'trainer': trainer,
+    }

+    latest_checkpoint = Checkpoint(
+        to_save,
+        DiskSaver('./results/checkpoints', create_dir=True, require_empty=False),
+        n_saved=1,
+        global_step_transform=global_step_from_engine(trainer),
+        filename_pattern="{name}.{ext}",
+    )

+    # neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)
+    # best_checkpoint = Checkpoint(
+    #     to_save,
+    #     DiskSaver('./results/checkpoints', create_dir=True, require_empty=False),
+    #     n_saved=1,
+    #     global_step_transform=global_step_from_engine(trainer),
+    #     score_name="loss",
+    #     score_function=neg_loss_score,
+    # )


+    # prev_checkpoint = torch.load('./results/checkpoints/checkpoint.pt', map_location=device)
+    # Checkpoint.load_objects(to_load=to_save, checkpoint=prev_checkpoint)

     @trainer.on(Events.STARTED)
     def log_training_start(trainer):
         logger.info(f"Training model on device: {device}")
@@ -184,6 +213,9 @@ def log_training_loss(trainer):
         logger.info(f"Epoch {trainer.state.epoch} run time: {trainer.state.times['EPOCH_COMPLETED']:.2f}[s]")
         logger.info(f"Epoch {trainer.state.epoch} metrics: {trainer.state.output}")

+    trainer.add_event_handler(Events.EPOCH_COMPLETED, latest_checkpoint)
+    # trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint)

     @trainer.on(Events.COMPLETED)
     def log_total_time(trainer):
         logger.info(f"Total training time: {trainer.state.times['COMPLETED']:.2f}[s]")
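Note: with n_saved=1 and filename_pattern="{name}.{ext}", the handler above overwrites a single ./results/checkpoints/checkpoint.pt at the end of every epoch. The commented-out lines sketch how training could later be resumed from that file; a minimal sketch once enabled, assuming model, trainer, and device are the objects already in scope inside create_trainer:

    import torch
    from ignite.handlers import Checkpoint

    # Keys must match those used in `to_save` when the checkpoint was written.
    to_load = {"model": model, "optimizer": model.optimizer, "trainer": trainer}

    # Read the saved state dicts and restore each object in place; a subsequent
    # trainer.run(...) then continues from the saved epoch rather than epoch 0.
    prev_checkpoint = torch.load("./results/checkpoints/checkpoint.pt", map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=prev_checkpoint)

Because the trainer itself is part of to_save, the restored engine state carries the epoch counter, which is what lets the run pick up where it left off.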
