Add NLMs (#270)
* Update pyproject.toml

* Create nlm.py

* Update data_utils.py

* Update plm.py

* Update inputs.py

* Update alphafold3.py

* Update test_af3.py
amorehead authored Sep 18, 2024
1 parent 0bc7541 commit f7cdbef
Showing 7 changed files with 348 additions and 31 deletions.
71 changes: 70 additions & 1 deletion alphafold3_pytorch/alphafold3.py
@@ -58,6 +58,8 @@
CONSTRAINTS,
CONSTRAINTS_MASK_VALUE,
IS_MOLECULE_TYPES,
IS_NA_INDICES,
IS_NON_NA_INDICES,
IS_PROTEIN_INDEX,
IS_DNA_INDEX,
IS_RNA_INDEX,
@@ -70,6 +72,9 @@
IS_RNA,
IS_LIGAND,
IS_METAL_ION,
MAX_DNA_NUCLEOTIDE_ID,
MIN_RNA_NUCLEOTIDE_ID,
MISSING_RNA_NUCLEOTIDE_ID,
NUM_HUMAN_AMINO_ACIDS,
NUM_MOLECULE_IDS,
NUM_MSA_ONE_HOT,
@@ -85,6 +90,11 @@
get_residue_constants,
)

from alphafold3_pytorch.nlm import (
NLMEmbedding,
NLMRegistry,
remove_nlms
)
from alphafold3_pytorch.plm import (
PLMEmbedding,
PLMRegistry,
@@ -149,7 +159,8 @@
dmf - additional msa feats derived from msa (has_deletion and deletion_value)
dtf - additional token feats derived from msa (profile and deletion_mean)
dac - additional pairwise token constraint embeddings
dpe - additional protein language model embeddings from esm
dpe - additional protein language model embeddings
dne - additional nucleotide language model embeddings
t - templates
s - msa
r - registers
@@ -5957,7 +5968,9 @@ def __init__(
detach_when_recycling = True,
pdb_training_set=True,
plm_embeddings: PLMEmbedding | tuple[PLMEmbedding, ...] | None = None,
nlm_embeddings: NLMEmbedding | tuple[NLMEmbedding, ...] | None = None,
plm_kwargs: dict | tuple[dict, ...] | None = None,
nlm_kwargs: dict | tuple[dict, ...] | None = None,
constraints: List[CONSTRAINTS] | None = None,
):
super().__init__()
@@ -6033,6 +6046,34 @@ def __init__(

self.to_plm_embeds = LinearNoBias(concatted_plm_embed_dim, dim_single)

# optional nucleotide language model(s) (NLM) embeddings

self.nlms = None

if exists(nlm_embeddings):
self.nlms = ModuleList([])

for one_nlm_embedding, one_nlm_kwargs in zip_longest(
cast_tuple(nlm_embeddings), cast_tuple(nlm_kwargs)
):
assert (
one_nlm_embedding in NLMRegistry
), f"Received invalid NLM embedding name: {one_nlm_embedding}. Acceptable ones are {list(NLMRegistry.keys())}."

constructor = NLMRegistry.get(one_nlm_embedding)

one_nlm_kwargs = default(one_nlm_kwargs, {})
nlm = constructor(**one_nlm_kwargs)

freeze_(nlm)

self.nlms.append(nlm)

if exists(self.nlms):
concatted_nlm_embed_dim = sum([nlm.embed_dim for nlm in self.nlms])

self.to_nlm_embeds = LinearNoBias(concatted_nlm_embed_dim, dim_single)

# atoms per window

self.atoms_per_window = atoms_per_window
@@ -6261,10 +6302,12 @@ def device(self):
return self.zero.device

@remove_plms
@remove_nlms
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)

@remove_plms
@remove_nlms
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)

@@ -6590,6 +6633,32 @@ def forward(

single_init = single_init + single_plm_init

# handle maybe nucleotide language model (NLM) embeddings

if exists(self.nlms):
na_ids = torch.where(
is_molecule_types[..., IS_NA_INDICES].any(dim=-1)
& (
(molecule_ids < MIN_RNA_NUCLEOTIDE_ID) | (molecule_ids > MAX_DNA_NUCLEOTIDE_ID)
),
MISSING_RNA_NUCLEOTIDE_ID,
molecule_ids,
)
molecule_na_ids = torch.where(
is_molecule_types[..., IS_NON_NA_INDICES].any(dim=-1),
-1,
na_ids,
)

nlm_embeds = [nlm(molecule_na_ids) for nlm in self.nlms]

# concat all NLM embeddings and project and add to single init

all_nlm_embeds = torch.cat(nlm_embeds, dim=-1)
single_nlm_init = self.to_nlm_embeds(all_nlm_embeds)

single_init = single_init + single_nlm_init

# relative positional encoding

relative_position_encoding = self.relative_position_encoding(
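To make the new forward-pass block above easier to follow, here is a minimal sketch of the nucleotide-id remapping it performs, run on a tiny hand-made batch. The shapes and values are illustrative only; the constants are the ones added to `alphafold3_pytorch/inputs.py` below.

```python
import torch

from alphafold3_pytorch.inputs import (
    IS_MOLECULE_TYPES,
    IS_NA_INDICES,
    IS_NON_NA_INDICES,
    MAX_DNA_NUCLEOTIDE_ID,
    MIN_RNA_NUCLEOTIDE_ID,
    MISSING_RNA_NUCLEOTIDE_ID,
)

batch, seq_len = 1, 4

# toy inputs: random molecule ids and a type mask marking every token as RNA
molecule_ids = torch.randint(0, MAX_DNA_NUCLEOTIDE_ID + 1, (batch, seq_len))
is_molecule_types = torch.zeros(batch, seq_len, IS_MOLECULE_TYPES, dtype=torch.bool)
is_molecule_types[..., 1] = True  # slot 1 is RNA (cf. IS_NA_INDICES = slice(1, 3))

# nucleic-acid tokens whose ids fall outside the RNA/DNA range become "missing RNA"
na_ids = torch.where(
    is_molecule_types[..., IS_NA_INDICES].any(dim=-1)
    & ((molecule_ids < MIN_RNA_NUCLEOTIDE_ID) | (molecule_ids > MAX_DNA_NUCLEOTIDE_ID)),
    MISSING_RNA_NUCLEOTIDE_ID,
    molecule_ids,
)

# everything that is not RNA/DNA is set to -1, which the NLM wrapper maps to its mask token
molecule_na_ids = torch.where(
    is_molecule_types[..., IS_NON_NA_INDICES].any(dim=-1),
    -1,
    na_ids,
)
```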
8 changes: 8 additions & 0 deletions alphafold3_pytorch/inputs.py
@@ -123,8 +123,11 @@
IS_DNA_INDEX = 2
IS_LIGAND_INDEX = -2
IS_METAL_ION_INDEX = -1

IS_BIOMOLECULE_INDICES = slice(0, 3)
IS_NON_PROTEIN_INDICES = slice(1, 5)
IS_NA_INDICES = slice(1, 3)
IS_NON_NA_INDICES = [0, 3, 4]

IS_PROTEIN, IS_RNA, IS_DNA, IS_LIGAND, IS_METAL_ION = tuple(
(IS_MOLECULE_TYPES + i if i < 0 else i)
@@ -144,6 +147,11 @@
NUM_HUMAN_AMINO_ACIDS = len(HUMAN_AMINO_ACIDS) - 1 # exclude unknown amino acid type
NUM_MSA_ONE_HOT = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 1

MIN_RNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS)
MAX_DNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) - 1

MISSING_RNA_NUCLEOTIDE_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) - 1

DEFAULT_NUM_MOLECULE_MODS = 4 # `mod_protein`, `mod_rna`, `mod_dna`, and `mod_unk`
ADDITIONAL_MOLECULE_FEATS = 5

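The new constants carve the shared molecule-id vocabulary into an amino-acid block followed by an RNA block and then a DNA block. The short sanity check below merely restates the definitions added above, with the names imported from `alphafold3_pytorch.inputs`.

```python
from alphafold3_pytorch.inputs import (
    DNA_NUCLEOTIDES,
    HUMAN_AMINO_ACIDS,
    MAX_DNA_NUCLEOTIDE_ID,
    MIN_RNA_NUCLEOTIDE_ID,
    MISSING_RNA_NUCLEOTIDE_ID,
    RNA_NUCLEOTIDES,
)

# amino-acid ids come first, so the nucleotide range starts right after them
assert MIN_RNA_NUCLEOTIDE_ID == len(HUMAN_AMINO_ACIDS)

# the RNA block is followed immediately by the DNA block
assert MAX_DNA_NUCLEOTIDE_ID == MIN_RNA_NUCLEOTIDE_ID + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) - 1

# the last id of the RNA block doubles as the "missing / unknown nucleotide" id
assert MISSING_RNA_NUCLEOTIDE_ID == MIN_RNA_NUCLEOTIDE_ID + len(RNA_NUCLEOTIDES) - 1
```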
117 changes: 117 additions & 0 deletions alphafold3_pytorch/nlm.py
@@ -0,0 +1,117 @@
from functools import wraps

import torch
from beartype.typing import Literal
from torch import tensor
from torch.nn import Module

from alphafold3_pytorch.common.biomolecule import get_residue_constants
from alphafold3_pytorch.inputs import IS_DNA, IS_RNA
from alphafold3_pytorch.tensor_typing import Float, Int, typecheck
from alphafold3_pytorch.utils.data_utils import join

# functions

def remove_nlms(fn):
"""Decorator to remove NLMs from the model before calling the inner function and then restore
them afterwards."""

@wraps(fn)
def inner(self, *args, **kwargs):
has_nlms = hasattr(self, "nlms")
if has_nlms:
nlms = self.nlms
delattr(self, "nlms")

out = fn(self, *args, **kwargs)

if has_nlms:
self.nlms = nlms

return out

return inner


# constants

rna_constants = get_residue_constants(res_chem_index=IS_RNA)
dna_constants = get_residue_constants(res_chem_index=IS_DNA)

rna_restypes = rna_constants.restypes + ["X"]
dna_restypes = dna_constants.restypes + ["X"]

rna_min_restype_num = rna_constants.min_restype_num
dna_min_restype_num = dna_constants.min_restype_num

RINALMO_MASK_TOKEN = "-" # nosec

# class


class RiNALMoWrapper(Module):
"""A wrapper for the RiNALMo model to provide NLM embeddings."""

def __init__(self):
super().__init__()
from multimolecule import RiNALMoModel, RnaTokenizer

self.register_buffer("dummy", tensor(0), persistent=False)

self.tokenizer = RnaTokenizer.from_pretrained(
"multimolecule/rinalmo", replace_T_with_U=False
)
self.model = RiNALMoModel.from_pretrained("multimolecule/rinalmo")

self.embed_dim = 1280

@torch.no_grad()
@typecheck
def forward(
self, na_ids: Int["b n"] # type: ignore
) -> Float["b n dne"]: # type: ignore
"""Get NLM embeddings for a batch of (pseudo-)nucleotide sequences.
:param na_ids: A batch of nucleotide residue indices.
:return: The NLM embeddings for the input sequences.
"""
device, seq_len = self.dummy.device, na_ids.shape[-1]

sequence_data = [
join(
[
(
RINALMO_MASK_TOKEN
if i == -1
else (
rna_restypes[i - rna_min_restype_num]
if rna_min_restype_num <= i < dna_min_restype_num
else dna_restypes[i - dna_min_restype_num]
)
)
for i in ids
]
)
for ids in na_ids
]

# encode to ids

inputs = self.tokenizer(sequence_data, return_tensors="pt").to(device)

# forward through nlm

embeddings = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)

# remove prefix

nlm_embeddings = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]

return nlm_embeddings


# NLM embedding type and registry

NLMRegistry = dict(rinalmo=RiNALMoWrapper)

NLMEmbedding = Literal["rinalmo"]
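
An illustrative usage sketch of the registry above: the toy input ids are made up, and the call assumes (as the forward pass in `alphafold3.py` does) that valid nucleotide ids start at `MIN_RNA_NUCLEOTIDE_ID`. Running it requires the `multimolecule` package and downloads the pretrained RiNALMo weights.

```python
import torch

from alphafold3_pytorch.inputs import MIN_RNA_NUCLEOTIDE_ID
from alphafold3_pytorch.nlm import NLMRegistry

nlm = NLMRegistry["rinalmo"]()  # constructs RiNALMoWrapper

# a toy batch: three RNA nucleotide ids plus one non-nucleic-acid position (-1)
na_ids = torch.tensor([[
    MIN_RNA_NUCLEOTIDE_ID,
    MIN_RNA_NUCLEOTIDE_ID + 1,
    -1,
    MIN_RNA_NUCLEOTIDE_ID + 2,
]])

embeds = nlm(na_ids)
print(embeds.shape)  # expected: (1, 4, 1280), matching nlm.embed_dim
```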
52 changes: 31 additions & 21 deletions alphafold3_pytorch/plm.py
@@ -9,35 +9,39 @@
from alphafold3_pytorch.common.biomolecule import get_residue_constants
from alphafold3_pytorch.inputs import IS_PROTEIN
from alphafold3_pytorch.tensor_typing import Float, Int, typecheck
from alphafold3_pytorch.utils.data_utils import join

# functions

def join(arr, delimiter = ''): # just redo an ugly part of python
return delimiter.join(arr)

def remove_plms(fn):
"""Decorator to remove PLMs from the model before calling the inner function and then restore
them afterwards."""

@wraps(fn)
def inner(self, *args, **kwargs):
has_plms = hasattr(self, 'plms')
has_plms = hasattr(self, "plms")
if has_plms:
plms = self.plms
delattr(self, 'plms')
delattr(self, "plms")

out = fn(self, *args, **kwargs)

if has_plms:
self.plms = plms

return out

return inner


# constants

aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
restypes = aa_constants.restypes + ["X"]

ESM_MASK_TOKEN = "-"
PROST_T5_MASK_TOKEN = "X"
ESM_MASK_TOKEN = "-" # nosec
PROST_T5_MASK_TOKEN = "X" # nosec

# class

@@ -70,7 +74,9 @@ def forward(
:param aa_ids: A batch of amino acid residue indices.
:return: The PLM embeddings for the input sequences.
"""
device, repr_layer = self.dummy.device, self.repr_layer
device, seq_len, repr_layer = self.dummy.device, aa_ids.shape[-1], self.repr_layer

# following the readme at https://github.com/facebookresearch/esm

sequence_data = [
(
Expand All @@ -80,18 +86,21 @@ def forward(
for mol_idx, ids in enumerate(aa_ids)
]

# encode to IDs

_, _, batch_tokens = self.batch_converter(sequence_data)
batch_tokens = batch_tokens.to(device)

# forward through plm

self.model.eval()
results = self.model(batch_tokens, repr_layers=[repr_layer])

token_representations = results["representations"][repr_layer]
embeddings = results["representations"][repr_layer]

# remove prefix

sequence_representations = []
for i, (_, seq) in enumerate(sequence_data):
sequence_representations.append(token_representations[i, 1 : len(seq) + 1])
plm_embeddings = torch.stack(sequence_representations, dim=0)
plm_embeddings = embeddings[:, 1 : (seq_len + 1)]

return plm_embeddings

@@ -121,20 +130,21 @@ def forward(
"""
device, seq_len = self.dummy.device, aa_ids.shape[-1]

str_sequences = [
join([(PROST_T5_MASK_TOKEN if i == -1 else restypes[i]) for i in ids]) for ids in aa_ids
]

# following the readme at https://github.com/mheinzinger/ProstT5

str_sequences = [
join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in str_sequences
sequence_data = [
join([(PROST_T5_MASK_TOKEN if i == -1 else restypes[i]) for i in ids])
for ids in aa_ids
]

sequence_data = [
join(list(re.sub(r"[UZOB]", "X", str_seq)), " ") for str_seq in sequence_data
]

# encode to ids

inputs = self.tokenizer.batch_encode_plus(
str_sequences, add_special_tokens=True, padding="longest", return_tensors="pt"
sequence_data, add_special_tokens=True, padding="longest", return_tensors="pt"
).to(device)

# forward through plm
Expand All @@ -143,8 +153,8 @@ def forward(

# remove prefix

plm_embedding = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
return plm_embedding
plm_embeddings = embeddings.last_hidden_state[:, 1 : (seq_len + 1)]
return plm_embeddings


# PLM embedding type and registry
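One note on the ESM change above: the per-sequence loop that previously stripped the leading special token is replaced by a single batched slice, which is equivalent here because every row of `aa_ids` shares the same padded length. A toy check of that equivalence, with made-up shapes:

```python
import torch

batch, seq_len, dim = 2, 5, 8

# pretend these are ESM token representations: prefix token + seq_len tokens + suffix token
token_representations = torch.randn(batch, seq_len + 2, dim)

# old approach: strip the prefix per sequence, then stack
old = torch.stack(
    [token_representations[i, 1 : seq_len + 1] for i in range(batch)], dim=0
)

# new approach: one vectorized slice over the batch
new = token_representations[:, 1 : seq_len + 1]

assert torch.equal(old, new)
```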
(Diffs for the remaining files — pyproject.toml, data_utils.py, and test_af3.py — are not shown.)
