complete prostT5 integration (#253)
lucidrains authored Sep 14, 2024
1 parent 9d5d05a commit fe1a0fd
Showing 3 changed files with 75 additions and 21 deletions.
13 changes: 7 additions & 6 deletions alphafold3_pytorch/alphafold3.py
@@ -4,7 +4,7 @@
import sh
from math import pi, sqrt
from pathlib import Path
- from itertools import product
+ from itertools import product, zip_longest
from functools import partial, wraps
from collections import namedtuple

@@ -213,7 +213,7 @@ def compact(*args):
return tuple(filter(exists, args))

def cast_tuple(t, length = 1):
- return (t,) if not isinstance(t, tuple) else ((t,) * length)
+ return t if isinstance(t, tuple) else ((t,) * length)

# tensor helpers

@@ -5960,7 +5960,7 @@ def __init__(
detach_when_recycling = True,
pdb_training_set=True,
plm_embeddings: PLMEmbeddings | tuple[PLMEmbedding, ...] | None = None,
- plm_kwargs: dict | tuple[dict, ...] = dict(),
+ plm_kwargs: dict | tuple[dict, ...] | None = None,
constraint_embeddings: int | None = None,
):
super().__init__()
@@ -6011,12 +6011,13 @@ def __init__(
if exists(plm_embeddings):
self.plms = ModuleList([])

- for one_plm_embedding, one_plm_kwargs in zip(cast_tuple(plm_embeddings), cast_tuple(plm_kwargs)):
- assert one_plm_embedding in PLMRegistry
+ for one_plm_embedding, one_plm_kwargs in zip_longest(cast_tuple(plm_embeddings), cast_tuple(plm_kwargs)):
+ assert one_plm_embedding in PLMRegistry, f'received invalid plm embedding name {one_plm_embedding} - acceptable ones are {PLMRegistry.keys()}'
constructor = PLMRegistry.get(one_plm_embedding)

one_plm_kwargs = default(one_plm_kwargs, {})
plm = constructor(**one_plm_kwargs)

freeze_(plm)

self.plms.append(plm)
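A minimal sketch of how the updated `cast_tuple` and `zip_longest` pair each configured PLM name with its kwargs; this is illustrative only, not from the commit, and the repository's `default` helper is assumed to behave like "value if not None else fallback", matching its use in the loop above:

```python
from itertools import zip_longest

def cast_tuple(t, length = 1):
    # tuples pass through unchanged, a single value is repeated `length` times
    return t if isinstance(t, tuple) else ((t,) * length)

def default(value, fallback):
    # assumed behavior of the repository's `default` helper
    return value if value is not None else fallback

plm_embeddings = ('esm2_t33_650M_UR50D', 'prostT5')
plm_kwargs = None  # no per-PLM kwargs supplied, the new default

# zip_longest pads the shorter side with None, which default() turns into {}
for name, kwargs in zip_longest(cast_tuple(plm_embeddings), cast_tuple(plm_kwargs)):
    print(name, default(kwargs, {}))

# esm2_t33_650M_UR50D {}
# prostT5 {}
```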
81 changes: 67 additions & 14 deletions alphafold3_pytorch/plm.py
@@ -1,3 +1,4 @@
import re
from functools import partial

import torch
@@ -20,6 +21,16 @@
IS_PROTEIN,
)

# functions

def join(arr, delimiter = ''): # just redo an ugly part of python
return delimiter.join(arr)

# constants

aa_constants = get_residue_constants(res_chem_index = IS_PROTEIN)
restypes_index = dict(enumerate(aa_constants.restypes))

# class

class ESMWrapper(Module):
@@ -46,20 +57,10 @@ def forward(

device, repr_layer = self.dummy.device, self.repr_layer

- aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
sequence_data = [
(
f"molecule{i}",
- "".join(
- [
- (
- aa_constants.restypes[id]
- if 0 <= id < len(aa_constants.restypes)
- else "X"
- )
- for id in ids
- ]
- ),
+ join([restypes_index.get(i, 'X') for i in ids]),
)
for i, ids in enumerate(aa_ids)
]
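A small sketch of the new `restypes_index` lookup used above; the 20-letter alphabet is a hypothetical stand-in for `aa_constants.restypes` (the real list comes from `get_residue_constants`), and ids without an entry fall back to 'X':

```python
# hypothetical stand-in for aa_constants.restypes
restypes = list('ARNDCQEGHILKMFPSTWYV')
restypes_index = dict(enumerate(restypes))

def join(arr, delimiter = ''):
    return delimiter.join(arr)

aa_ids = [0, 2, 4, 21]  # 21 has no entry among the 20 canonical residues
print(join([restypes_index.get(i, 'X') for i in aa_ids]))  # 'ANCX'
```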
@@ -79,10 +80,62 @@ def forward(

return plm_embeddings

- # PLM embedding type and registry
class ProstT5Wrapper(Module):
def __init__(self):
super().__init__()
from transformers import T5Tokenizer, T5EncoderModel

self.register_buffer('dummy', tensor(0), persistent = False)

self.tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case = False)
self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
self.embed_dim = 1024

def forward(
self,
aa_ids: Int['b n']
) -> Float['b n dpe']:

PLMEmbedding = Literal["esm2_t33_650M_UR50D"]
device, seq_len = self.dummy.device, aa_ids.shape[-1]

str_sequences = [
join([restypes_index.get(i, 'X') for i in ids])
for i, ids in enumerate(aa_ids)
]

# following the readme at https://github.com/mheinzinger/ProstT5

str_sequences = [join(list(re.sub(r"[UZOB]", "X", str_seq)), ' ') for str_seq in str_sequences]

# encode to ids

inputs = self.tokenizer.batch_encode_plus(
str_sequences,
add_special_tokens = True,
padding = "longest",
return_tensors = 'pt'
).to(device)

# forward through plm

embeddings = self.model(
inputs.input_ids,
attention_mask = inputs.attention_mask
)

# remove prefix

plm_embedding = embeddings.last_hidden_state[:, 1:(seq_len + 1)]
return plm_embedding

# PLM embedding type and registry

PLMRegistry = dict(
- esm2_t33_650M_UR50D = partial(ESMWrapper, 'esm2_t33_650M_UR50D')
+ esm2_t33_650M_UR50D = partial(ESMWrapper, 'esm2_t33_650M_UR50D'),
+ prostT5 = ProstT5Wrapper
)

PLMEmbedding = Literal[
"esm2_t33_650M_UR50D",
"prostT5"
]
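A sketch of the ProstT5 input preprocessing performed in `ProstT5Wrapper.forward` above, following the ProstT5 README at https://github.com/mheinzinger/ProstT5: rare or ambiguous residues U, Z, O, B are mapped to X and the sequence is space-separated before tokenization. The helper name below is hypothetical.

```python
import re

def preprocess_for_prostt5(seq: str) -> str:
    # map rare / ambiguous amino acids to X, then space-separate the residues
    seq = re.sub(r"[UZOB]", "X", seq)
    return ' '.join(list(seq))

print(preprocess_for_prostt5('MKTUZA'))  # 'M K T X X A'
```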
2 changes: 1 addition & 1 deletion tests/test_af3.py
@@ -1095,7 +1095,7 @@ def test_alphafold3_with_plm_embeddings():
num_molecule_mods=0,
dim_atom_inputs=77,
dim_template_feats=108,
plm_embeddings="esm2_t33_650M_UR50D",
plm_embeddings=("esm2_t33_650M_UR50D", "prostT5"),
)

# mock inputs
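Putting it together, a hedged usage sketch mirroring the updated test: both registered PLMs are enabled by passing their names as a tuple. The test passes additional `Alphafold3` constructor arguments that are elided from the diff above, so treat this as schematic rather than a complete call:

```python
from alphafold3_pytorch import Alphafold3

af3 = Alphafold3(
    num_molecule_mods = 0,
    dim_atom_inputs = 77,
    dim_template_feats = 108,
    plm_embeddings = ("esm2_t33_650M_UR50D", "prostT5"),
    # plm_kwargs defaults to None; zip_longest + default() then give each PLM empty kwargs
)
```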
