diff --git a/alphafold3_pytorch/trainer.py b/alphafold3_pytorch/trainer.py index 72139525..03ceff64 100644 --- a/alphafold3_pytorch/trainer.py +++ b/alphafold3_pytorch/trainer.py @@ -112,6 +112,7 @@ def __init__( accelerator = 'auto', checkpoint_every: int = 1000, checkpoint_folder: str = './checkpoints', + overwrite_checkpoints: bool = False, fabric_kwargs: dict = dict(), ema_kwargs: dict = dict() ): @@ -199,6 +200,7 @@ def __init__( # checkpointing logic self.checkpoint_every = checkpoint_every + self.overwrite_checkpoints = overwrite_checkpoints self.checkpoint_folder = Path(checkpoint_folder) self.checkpoint_folder.mkdir(exist_ok = True, parents = True) @@ -367,7 +369,9 @@ def __call__( self.wait() if self.is_main and divisible_by(self.steps, self.checkpoint_every): - self.save(self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt') + checkpoint_path = self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt' + + self.save(checkpoint_path, overwrite = self.overwrite_checkpoints) self.wait() diff --git a/pyproject.toml b/pyproject.toml index 7047dfca..6ec09a92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.0.61" +version = "0.0.62" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 11a5a90e..e902bf0b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,6 +1,8 @@ import os os.environ['TYPECHECK'] = 'True' +from pathlib import Path + import pytest import torch from torch.utils.data import Dataset, DataLoader @@ -133,7 +135,8 @@ def test_trainer(): batch_size = 1, valid_every = 1, grad_accum_every = 2, - checkpoint_every = 1 + checkpoint_every = 1, + overwrite_checkpoints = True ) trainer()