Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xnuohz committed Oct 16, 2024
1 parent 9ee68e5 commit 1f61674
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 27 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`

| | `cpu` | `cu118` | `cu121` | `cu124` |
| ----------- | ----- | ------- | ------- | ------- |
| **Linux** |||||
| **Windows** |||||
| **macOS** || | | |
| **Linux** | | | | |
| **Windows** | | | | |
| **macOS** | | | | |

#### PyTorch 2.3

Expand All @@ -399,9 +399,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` dependin

| | `cpu` | `cu118` | `cu121` |
| ----------- | ----- | ------- | ------- |
| **Linux** ||||
| **Windows** ||||
| **macOS** || | |
| **Linux** | | | |
| **Windows** | | | |
| **macOS** | | | |

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, and PyTorch 2.2.0/2.2.1/2.2.2 (following the same procedure).
**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
Expand Down
21 changes: 16 additions & 5 deletions test/nn/models/test_molecule_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
from torch.nn import ReLU
from torch.nn import Sequential as Seq

from torch_geometric.nn import GINConv, MoleculeGPT
from torch_geometric.nn import GINEConv, MoleculeGPT
from torch_geometric.nn.nlp import LLM


def test_molecule_gpt() -> None:
llm = LLM(
model_name='lmsys/vicuna-7b-v1.5',
# model_name='lmsys/vicuna-7b-v1.5',
model_name='DeepChem/ChemBERTa-77M-MTR',
num_params=7,
dtype=torch.bfloat16,
)

graph_encoder = GINConv(nn=Seq(Lin(16, 32), ReLU(), Lin(32, 32)),
train_eps=True)
graph_encoder = GINEConv(nn=Seq(Lin(16, 32), ReLU(), Lin(32, 32)),
train_eps=True, edge_dim=16)

smiles_encoder = LLM(
model_name='DeepChem/ChemBERTa-77M-MTR',
Expand All @@ -27,8 +28,18 @@ def test_molecule_gpt() -> None:
llm=llm,
graph_encoder=graph_encoder,
smiles_encoder=smiles_encoder,
use_lora=True,
mlp_out_channels=4096,
)

assert 'MoleculeGPT' in str(model)

x = torch.randn(10, 16)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])
edge_attr = torch.randn(edge_index.size(1), 16)
batch = torch.zeros(x.size(0), dtype=torch.long)
smiles = ['CCCCCCCCCC']

model(x, edge_index, batch, edge_attr, smiles)
40 changes: 24 additions & 16 deletions torch_geometric/nn/models/molecule_gpt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -46,43 +46,51 @@ def __init__(

self.word_embedding = self.llm.word_embedding
self.llm_generator = self.llm.llm
mlp_hidden_channels = self.gnn.out_channels
# TODO: Add Q-Former layer
self.projector = torch.nn.Sequential(
torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
torch.nn.ReLU(),
torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
).to(self.llm.device)

def encode(
def graph_encode(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
edge_attr: Optional[Tensor],
smiles: Optional[Tensor],
) -> Tensor:
x = x.to(self.llm.device)
edge_index = edge_index.to(self.llm.device)
if edge_attr is not None:
edge_attr = edge_attr.to(self.llm.device)
batch = batch.to(self.llm.device)
smiles = smiles.to(self.llm.device)

out = self.gnn(x, edge_index, edge_attr=edge_attr)
graph_embedding = scatter(out, batch, dim=0, reduce='mean')
smiles_embedding = self.smiles_encoder(smiles)
return graph_embedding, smiles_embedding
out = self.graph_encoder(x, edge_index, edge_attr=edge_attr)
return scatter(out, batch, dim=0, reduce='mean')

def forward(self):
def smiles_encode(
self,
smiles: List[str],
):
pass

def forward(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
edge_attr: Optional[Tensor],
smiles: List[str],
):
x = self.graph_encode(x, edge_index, batch,
edge_attr) # graph branch [bs, d]

import pdb
pdb.set_trace()

@torch.no_grad()
def inference(self):
pass

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(\n'
f' llm={self.llm},\n'
f' gnn={self.gnn},\n'
f' graph={self.graph_encoder},\n'
f' smiles={self.smiles_encoder},\n'
f')')

0 comments on commit 1f61674

Please sign in to comment.