Skip to content

Commit

Permalink
fix trainer save / load
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 29, 2024
1 parent 5fe0cbd commit 99fd519
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions alphafold3_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,16 @@ def save(self, path: str | Path, overwrite = False):
steps = self.steps
)

torch.save(str(path), package)
torch.save(package, str(path))

def load(self, path: str | Path, strict = True):
if isinstance(path, str):
path = Path(path)

assert path.exists()

self.model.load(path)

package = torch.load(str(path))

if 'optimizer' in package:
Expand All @@ -233,8 +235,6 @@ def load(self, path: str | Path, strict = True):

self.steps = package.get('steps', 0)

self.model.load_state_dict(package['model'])

# shortcut methods

def wait(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.59"
version = "0.0.60"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
2 changes: 1 addition & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_trainer():

# saving and loading from trainer

trainer.save('./some/nested/folder2/training')
trainer.save('./some/nested/folder2/training', overwrite = True)
trainer.load('./some/nested/folder2/training')

# also allow for loading Alphafold3 directly from training ckpt
Expand Down

0 comments on commit 99fd519

Please sign in to comment.