From 585aca0698086d8eb200d9d110065816666fd836 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 29 May 2024 08:36:20 -0700
Subject: [PATCH] complete saving and loading trainer states

---
 alphafold3_pytorch/trainer.py | 61 +++++++++++++++++++++++++++++++++++
 pyproject.toml                |  2 +-
 tests/test_trainer.py         |  9 ++++++++
 3 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/alphafold3_pytorch/trainer.py b/alphafold3_pytorch/trainer.py
index c782113a..71cdf698 100644
--- a/alphafold3_pytorch/trainer.py
+++ b/alphafold3_pytorch/trainer.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from pathlib import Path
+
 from alphafold3_pytorch.alphafold3 import Alphafold3
 
 from typing import TypedDict
@@ -187,6 +189,63 @@ def __init__(
     def is_main(self):
         return self.fabric.global_rank == 0
 
+    # saving and loading
+
+    def save(self, path: str | Path, overwrite = False):
+        """
+        Write the model, optimizer and scheduler state plus the step count to `path`.
+
+        Missing parent folders are created. Asserts that `path` is not a directory
+        and that it does not exist yet, unless `overwrite` is set.
+        """
+
+        if isinstance(path, str):
+            path = Path(path)
+
+        assert not path.is_dir() and (not path.exists() or overwrite)
+
+        path.parent.mkdir(exist_ok = True, parents = True)
+
+        package = dict(
+            model = self.model.state_dict_with_init_args,
+            optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
+            steps = self.steps
+        )
+
+        # torch.save expects the object first, then the destination
+
+        torch.save(package, str(path))
+
+    def load(self, path: str | Path, strict = True):
+        """
+        Restore trainer state from the checkpoint at `path`.
+
+        Optimizer and scheduler state are only restored when present in the
+        checkpoint; the step count falls back to 0 when it is missing.
+        """
+
+        # NOTE(review): `strict` is accepted but never forwarded - confirm intent
+
+        if isinstance(path, str):
+            path = Path(path)
+
+        assert path.exists()
+
+        package = torch.load(str(path))
+
+        if 'optimizer' in package:
+            self.optimizer.load_state_dict(package['optimizer'])
+
+        if 'scheduler' in package:
+            self.scheduler.load_state_dict(package['scheduler'])
+
+        self.steps = package.get('steps', 0)
+
+        self.model.load_state_dict(package['model'])
+
+    # shortcut methods
+
     def wait(self):
         self.fabric.barrier()
 
@@ -196,6 +255,8 @@ def print(self, *args, **kwargs):
     def log(self, **log_data):
         self.fabric.log_dict(log_data, step = self.steps)
 
+    # main train forwards
+
     def __call__(
         self
     ):
diff --git a/pyproject.toml b/pyproject.toml
index 0f8c5845..0d8adcc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.57"
+version = "0.0.58"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index c758d8c1..49f1de2c 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -134,3 +134,12 @@ def test_trainer():
     )
 
     trainer()
+
+    # saving and loading from trainer
+
+    trainer.save('./some/nested/folder2/training', overwrite = True)
+    trainer.load('./some/nested/folder2/training')
+
+    # also allow for loading Alphafold3 directly from training ckpt
+
+    alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training')