Skip to content

Commit

Permalink
fix ckpt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 29, 2024
1 parent 232c707 commit c02059d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
6 changes: 5 additions & 1 deletion alphafold3_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
accelerator = 'auto',
checkpoint_every: int = 1000,
checkpoint_folder: str = './checkpoints',
overwrite_checkpoints: bool = False,
fabric_kwargs: dict = dict(),
ema_kwargs: dict = dict()
):
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(
# checkpointing logic

self.checkpoint_every = checkpoint_every
self.overwrite_checkpoints = overwrite_checkpoints
self.checkpoint_folder = Path(checkpoint_folder)

self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
Expand Down Expand Up @@ -367,7 +369,9 @@ def __call__(
self.wait()

if self.is_main and divisible_by(self.steps, self.checkpoint_every):
self.save(self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt')
checkpoint_path = self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt'

self.save(checkpoint_path, overwrite = self.overwrite_checkpoints)

self.wait()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.61"
version = "0.0.62"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
5 changes: 4 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
os.environ['TYPECHECK'] = 'True'

from pathlib import Path

import pytest
import torch
from torch.utils.data import Dataset, DataLoader
Expand Down Expand Up @@ -133,7 +135,8 @@ def test_trainer():
batch_size = 1,
valid_every = 1,
grad_accum_every = 2,
checkpoint_every = 1
checkpoint_every = 1,
overwrite_checkpoints = True
)

trainer()
Expand Down

0 comments on commit c02059d

Please sign in to comment.