Skip to content

Commit

Permalink
complete saving and loading trainer states
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 29, 2024
1 parent d80a997 commit 585aca0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
43 changes: 43 additions & 0 deletions alphafold3_pytorch/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pathlib import Path

from alphafold3_pytorch.alphafold3 import Alphafold3

from typing import TypedDict
Expand Down Expand Up @@ -187,6 +189,45 @@ def __init__(
def is_main(self):
    """Report whether this process is the main one (fabric global rank 0)."""
    # NOTE(review): any decorator for this method (e.g. a property) would sit
    # above the visible hunk - confirm how callers access it.
    rank = self.fabric.global_rank
    return rank == 0

# saving and loading

def save(self, path: str | Path, overwrite = False):
    """Write the trainer state to a single checkpoint file.

    The checkpoint is a dict with the keys `model`, `optimizer`,
    `scheduler` and `steps`, serialized with `torch.save`.

    Args:
        path: Destination file. Missing parent directories are created.
        overwrite: Allow replacing something that already exists at `path`.

    Raises:
        AssertionError: If `path` is a directory, or if it already exists
            and `overwrite` is falsy.
    """
    if isinstance(path, str):
        path = Path(path)

    assert not path.is_dir(), f'{str(path)} is a directory, expected a file path'
    assert overwrite or not path.exists(), f'{str(path)} already exists, pass overwrite = True to replace it'

    path.parent.mkdir(exist_ok = True, parents = True)

    package = dict(
        model = self.model.state_dict_with_init_args,
        optimizer = self.optimizer.state_dict(),
        scheduler = self.scheduler.state_dict(),
        steps = self.steps
    )

    # torch.save expects (obj, destination); the arguments were previously
    # reversed, which failed instead of writing the checkpoint
    torch.save(package, str(path))

def load(self, path: str | Path, strict = True):
    """Restore trainer state from the checkpoint file at `path`.

    Optimizer and scheduler state are restored only when their keys are
    present in the checkpoint; `steps` falls back to 0 when absent; the
    `model` entry is required.

    Args:
        path: Checkpoint file to read.
        strict: Accepted but not used anywhere in this method.
            NOTE(review): presumably meant to be forwarded to the model's
            `load_state_dict` - confirm before wiring it through.

    Raises:
        AssertionError: If `path` does not exist.
        KeyError: If the checkpoint has no `model` entry.
    """
    checkpoint_path = Path(path) if isinstance(path, str) else path

    assert checkpoint_path.exists()

    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the torch version and `weights_only` - only load trusted checkpoints.
    package = torch.load(str(checkpoint_path))

    # optional pieces of state, restored in the same order as before
    for key in ('optimizer', 'scheduler'):
        if key in package:
            getattr(self, key).load_state_dict(package[key])

    self.steps = package.get('steps', 0)

    # NOTE(review): the layout of package['model'] depends on
    # `state_dict_with_init_args`, which is defined outside this view.
    self.model.load_state_dict(package['model'])

# shortcut methods

def wait(self):
    """Call the fabric barrier; returns nothing."""
    fabric = self.fabric
    fabric.barrier()

Expand All @@ -196,6 +237,8 @@ def print(self, *args, **kwargs):
def log(self, **log_data):
    """Send the keyword arguments to fabric's `log_dict` as one dict.

    The entry is tagged with the trainer's current `steps` value.
    """
    current_step = self.steps
    self.fabric.log_dict(log_data, step = current_step)

# main train forwards

def __call__(
self
):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.57"
version = "0.0.58"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
9 changes: 9 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,12 @@ def test_trainer():
)

trainer()

# saving and loading from trainer

trainer.save('./some/nested/folder2/training')
trainer.load('./some/nested/folder2/training')

# also allow for loading Alphafold3 directly from training ckpt

alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training')

0 comments on commit 585aca0

Please sign in to comment.