diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index f38c16a4..8851539a 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5959,7 +5959,7 @@ def __init__( checkpoint_diffusion_module = False, detach_when_recycling = True, pdb_training_set=True, - plm_embeddings: PLMEmbeddings | tuple[PLMEmbedding, ...] | None = None, + plm_embeddings: PLMEmbedding | tuple[PLMEmbedding, ...] | None = None, plm_kwargs: dict | tuple[dict, ...] | None = None, constraint_embeddings: int | None = None, ):