Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mol opt #24

Merged
merged 48 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
9587e84
add mol opt
tigranfah Apr 19, 2024
2a9cc8c
mol opt+rej sample
tigranfah Apr 22, 2024
ab876f9
minor logging changes
tigranfah Apr 22, 2024
56e9232
rej-sample-v1
tigranfah Apr 24, 2024
de7a695
add molecules pool dump
tigranfah Apr 27, 2024
40f7dde
add pool dump
tigranfah Apr 29, 2024
7559e5b
add feature
tigranfah May 3, 2024
a023281
don't enter scoring code if oracle has exceeded budget
philippguevorguian May 7, 2024
8ae3162
add file generation utility
philippguevorguian May 7, 2024
f2ee468
rej-sample-v2
tigranfah May 10, 2024
44cfb4b
rej-sample-v2.1
tigranfah May 13, 2024
898f5c1
refine
tigranfah May 14, 2024
a68b4f2
rej-sample-v2
tigranfah May 10, 2024
f5e1ca2
rej-sample-v2.1
tigranfah May 13, 2024
52b3a64
merge
tigranfah May 14, 2024
ca32500
merge
tigranfah May 14, 2024
cff4db7
rej-sample-v2 refac
tigranfah May 15, 2024
efd8106
remove duplicates if any from the optim process
tigranfah May 17, 2024
cade253
add hash type for molecular entries to store
philippguevorguian May 17, 2024
70642da
merge with local
philippguevorguian May 17, 2024
a805c1b
add logit processors to package
philippguevorguian May 17, 2024
20d3fe0
add train condition
tigranfah May 17, 2024
bfa32d7
don't give oracle score tag before the first training
tigranfah May 18, 2024
d51cbfb
add validation set to fine tuning
tigranfah May 20, 2024
1502e65
replace [ORACLE_SCORE] with [PROPERTY] tag
tigranfah May 20, 2024
188f384
add additional properties
tigranfah May 21, 2024
d95e4d4
correct properties order
tigranfah May 21, 2024
800278b
fix entry duplicates in pool issue
tigranfah May 21, 2024
1306842
pre merge
philippguevorguian May 22, 2024
d9f85d1
Merge branch 'mol_opt' of https://github.com/YerevaNN/ChemLactica int…
philippguevorguian May 22, 2024
b98a2cb
add validation batch size, to avoid memory error
tigranfah May 22, 2024
32cc5aa
small fix
tigranfah May 22, 2024
0ac3a7b
small fixes
tigranfah May 25, 2024
d8ac65f
add hparam config
tigranfah May 29, 2024
05af82d
add no slurm parallel runs to mol_opt
tigranfah May 29, 2024
16a00ca
keep optim state during the fine-tuning of mol optim process
tigranfah May 29, 2024
832906c
refine
tigranfah Jun 3, 2024
1f8142f
add lr annealing/not annealing
tigranfah Jun 6, 2024
91d9390
add ending lr in config
tigranfah Jun 6, 2024
2cf0f8b
remove multiprocessing
tigranfah Jun 6, 2024
9c96919
change sampling from pool
tigranfah Jun 7, 2024
5d7a258
fix
tigranfah Jun 7, 2024
be4096d
add model selection based on validation loss
tigranfah Jun 8, 2024
a84c79c
remove model selection with validation loss, because of slowness
tigranfah Jun 9, 2024
c393b1a
take random permutation of top elements from the pool, when construct…
tigranfah Jun 11, 2024
5537720
add custom model selection with fine-tuning validation loss
tigranfah Jun 11, 2024
10feca8
take a random subset from pool, when creating a prompt
tigranfah Jun 20, 2024
52692ec
fix
tigranfah Jun 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added chemlactica/mol_opt/__init__.py
Empty file.
191 changes: 191 additions & 0 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import torch
from datasets import Dataset
import multiprocessing
import gc
import re
import numpy as np
from collections import namedtuple
from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number
from chemlactica.mol_opt.tunning import supervised_fine_tune


def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range):
    """Build generation prompts seeded with random molecules from the pool.

    Each prompt lists up to `max_similars_in_prompt` pool molecules as
    ``[SIMILAR]<smiles> <score>[/SIMILAR]`` tags (score drawn uniformly from
    `sim_range`) and ends with an open ``[START_SMILES]`` tag for the model
    to complete.
    """
    prompts = []
    for _ in range(num_prompts):
        seed_molecules = molecule_pool.random_subset(max_similars_in_prompt)
        similar_tags = "".join(
            f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]"
            for mol in seed_molecules
        )
        prompts.append("</s>" + similar_tags + "[START_SMILES]")
    return prompts


def create_molecule_entry(output_text):
    """Parse one generated text into a MoleculeEntry.

    Extracts the SMILES between [START_SMILES] and [END_SMILES]. Returns None
    when the tags are missing, when the generated SMILES merely copies one of
    the [SIMILAR] prompt molecules, or when MoleculeEntry rejects the SMILES.
    """
    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
    start_ind = output_text.find(start_smiles_tag)
    end_ind = output_text.find(end_smiles_tag)
    if start_ind == -1 or end_ind == -1:
        return None

    generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
    # Reject copies of any prompt molecule. Use [1:] (not [1:-1]): the final
    # [SIMILAR] segment also contains the generated part, and slicing it off
    # let copies of the last prompt molecule slip through.
    for similar in output_text.split("[SIMILAR]")[1:]:
        similar_smiles = similar.split(" ")[0]
        if generated_smiles == similar_smiles:
            return None
    try:
        return MoleculeEntry(
            smiles=generated_smiles,
        )
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; invalid SMILES simply yields no entry.
        return None


def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs):
    """Prompt the model to predict a tagged property for `smiles`.

    Builds a ``[START_SMILES]...[END_SMILES][<tag>]`` prompt, generates with
    `prop_pred_kwargs`, and returns a list with the text found between the
    property tags for each decoded output (None when the tags are absent).
    """
    property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]"
    # Reuse the precomputed start tag instead of re-formatting it inline.
    prompts = [f"</s>[START_SMILES]{smiles}[END_SMILES]{property_start_tag}"]
    data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    # Not every tokenizer emits token_type_ids; `pop` with a default avoids
    # the KeyError a bare `del` would raise in that case.
    data.pop("token_type_ids", None)
    outputs = model.generate(
        **data,
        **prop_pred_kwargs
    )
    predicted_property_values = []
    output_texts = tokenizer.batch_decode(outputs)
    for output_text in output_texts:
        start_ind = output_text.find(property_start_tag)
        end_ind = output_text.find(property_end_tag)
        if start_ind != -1 and end_ind != -1:
            predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind])
        else:
            predicted_property_values.append(None)
    return predicted_property_values


def optimize(
    model, tokenizer,
    oracle, config
):
    """Run the molecule-optimization loop until the oracle budget is spent.

    Each iteration: sample prompts from the molecule pool, generate candidate
    SMILES with `model`, score valid novel candidates with `oracle`, and add
    them to the pool. With config["strategy"] == "rej-sample", the
    top-scoring generations are accumulated and periodically used to
    fine-tune the model. Progress is written to config["log_dir"].

    NOTE(review): the loop stops only via oracle.finish — presumably the
    oracle enforces a call budget; confirm against the oracle implementation.
    """
    print("config", config)
    # print("molecule generation arguments", config["generation_config"])
    molecule_pool = MoleculePool(config["molecule_pool_size"])

    if config["strategy"] == "rej-sample":
        training_entries = []

    num_iter = 1
    # Context manager guarantees the log file is closed on exit or error
    # (it was previously opened and never closed).
    with open(config["log_dir"], "w") as file:
        while True:
            model.eval()
            current_entries = []
            while len(current_entries) < config["num_gens_per_iter"]:
                prompts = create_optimization_prompts(
                    config["num_gens_per_iter"], molecule_pool,
                    max_similars_in_prompt=config["max_similars_in_prompt"],
                    sim_range=config["sim_range"]
                )
                output_texts = []
                generation_batch_size = 100
                for i in range(0, len(prompts), generation_batch_size):
                    prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)]
                    data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
                    del data["token_type_ids"]
                    output = model.generate(
                        **data,
                        **config["generation_config"]
                    )
                    # Free generation buffers between batches to limit GPU peak memory.
                    gc.collect()
                    torch.cuda.empty_cache()
                    output_texts.extend(tokenizer.batch_decode(output))

                # Parse generations in parallel; output_texts is 1:1 with prompts,
                # so index i maps each entry back to the prompt that produced it.
                with multiprocessing.Pool(processes=config["num_processes"]) as pol:
                    for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)):
                        if entry:
                            entry.score = oracle(entry.smiles)
                            entry.additional_properties["prompt"] = prompts[i]
                            current_entries.append(entry)
                            file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n")
                        if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]:
                            break

                if oracle.finish:
                    break

            # Deduplicate; np.unique sorts ascending, so reverse for best-first.
            current_entries = list(np.unique(current_entries))[::-1]
            num_iter += 1
            if oracle.finish:
                break

            if config["strategy"] == "rej-sample":
                top_k = int(len(current_entries) * config["rej_sample_config"]["rej_perc"])
                training_entries.extend(current_entries[:top_k])
                training_entries = list(np.unique(training_entries))[::-1]
                if len(training_entries) >= config["rej_sample_config"]["num_samples_per_round"]:
                    print(f"Num of train examples {len(training_entries)}.")
                    file.write("Training entries\n")
                    for i, mol in enumerate(training_entries):
                        file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n")
                    train_dataset = Dataset.from_dict({
                        "sample": [
                            f"{entry.additional_properties['prompt']}{entry.smiles}[END_SMILES]</s>"
                            for entry in training_entries
                        ]
                    })
                    config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"]
                    supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"])
                    training_entries = []
                    gc.collect()
                    torch.cuda.empty_cache()

            molecule_pool.add(current_entries)
            file.write("Molecule pool\n")
            for i, mol in enumerate(molecule_pool.molecule_entries):
                file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n")


# def optimize_reinvent(
# model, prior_model,
# tokenizer, oracle,
# config
# ):
# print("molecule pool size", config["molecule_pool_size"])
# print("molecule generation arguments", config["mol_gen_kwargs"])
# molecule_pool = MoleculePool(config["molecule_pool_size"])

# num_iter = 1
# while True:
# generated_entries = []
# while len(generated_entries) < config["num_gens_per_iter"]:
# prompts = create_optimization_prompts(
# config["num_gens_per_iter"], molecule_pool,
# max_similars_in_prompt=config["max_similars_in_prompt"],
# sim_range=config["sim_range"]
# )
# output_texts = []
# generation_batch_size = 200
# for i in range(0, len(prompts), generation_batch_size):
# prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)]
# data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
# del data["token_type_ids"]
# output = model.generate(
# **data,
# **config["mol_gen_kwargs"]
# )
# gc.collect()
# torch.cuda.empty_cache()
# output_texts.extend(tokenizer.batch_decode(output))

# with multiprocessing.Pool(processes=config["num_processes"]) as pol:
# for entry in pol.map(create_molecule_entry, output_texts) if entry]:
# entry.score = oracle(entry.smiles)
# generated_entries.append(entry)
# if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]:
# break

# if oracle.finish:
# break

# num_iter += 1
# if oracle.finish:
# break
# molecule_pool.add(generated_entries)
96 changes: 96 additions & 0 deletions chemlactica/mol_opt/oracle_estimators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import List
import time
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoModel, AutoConfig, AutoTokenizer
import torch
import torch.nn as nn
import numpy as np
from chemlactica.mol_opt.utils import MoleculeEntry
from sklearn.linear_model import Ridge


def find_second_eos_token_indices(sequences, eos_token_id):
    """Locate occurrences of `eos_token_id` in `sequences`, skipping column 0.

    Returns the (row, column) index tensors from torch.where applied to
    ``sequences[:, 1:]``; column indices are therefore offset by one
    relative to the full sequence (the leading token is ignored).
    """
    eos_mask = sequences[:, 1:] == eos_token_id
    return torch.where(eos_mask)


def init_linear_layer(layer, emb_length=None):
    """Initialize a linear layer in place and return it.

    Weights are drawn from N(0, 1/sqrt(emb_length + 1)); the bias is zeroed.
    `emb_length` defaults to ``layer.in_features``, which keeps existing
    two-argument calls working and fixes the one-argument call in
    ScalarHeadLM that previously raised TypeError.
    """
    if emb_length is None:
        emb_length = layer.in_features
    torch.nn.init.normal_(
        layer.weight,
        mean=0.0, std=1 / np.sqrt(emb_length + 1)
    )
    torch.nn.init.constant_(layer.bias, val=0.0)
    return layer


class ScalarHeadLM(PreTrainedModel):
    """A pretrained LM backbone with a scalar regression head.

    The head maps each position's last hidden state to a single scalar,
    yielding per-token predictions of shape (batch, seq_len, 1).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.lm_backbone = AutoModel.from_pretrained(
            config._name_or_path,
            config=config
        )
        self.scalar_head = nn.Linear(config.hidden_size, 1)
        # Pass the input width explicitly: init_linear_layer requires
        # emb_length, so the original one-argument call raised TypeError.
        init_linear_layer(self.scalar_head, config.hidden_size)

    def forward(self, **kwargs):
        """Return per-token scalar predictions over the backbone's hidden states."""
        output = self.lm_backbone(**kwargs)
        return self.scalar_head(output.last_hidden_state)


class LinearFingerprintModel:
    """Ridge regression over molecule fingerprints as a cheap oracle proxy.

    Call `fit` with scored entries to (re)train on every entry seen so far;
    call the instance to predict scores for new entries.
    """

    def __init__(self):
        self.emb_length = 2048
        self.linear = Ridge()
        self.all_entries = []
        self.is_fit = False

    def __call__(self, mol_entries: List[MoleculeEntry]):
        """Predict scores from each entry's fingerprint."""
        fingerprints = np.array([entry.fingerprint for entry in mol_entries])
        return self.linear.predict(fingerprints)

    def fit(self, mol_entries: List[MoleculeEntry]):
        """Add new entries to the training set and refit on all entries."""
        self.is_fit = True
        tic = time.time()
        self.all_entries.extend(mol_entries)
        xs = np.array([entry.fingerprint for entry in self.all_entries])
        ys = np.array([entry.score for entry in self.all_entries])
        self.linear.fit(xs, ys)
        print(f"Fit time {time.time() - tic:.4f}s")


class ScalarOracleApproximator:
    """Scores molecule entries with a ScalarHeadLM built from `config`.

    NOTE(review): __call__ currently only prints the head's outputs and
    returns None — presumably a work in progress; confirm intended API.
    """

    def __init__(self, config, tokenizer):
        self.scalar_head_lm = ScalarHeadLM(config)
        self.tokenizer = tokenizer

    def __call__(self, mol_entries):
        prompts = [
            f"</s>[START_SMILES]{entry.smiles}[END_SMILES]</s>"
            for entry in mol_entries
        ]
        data = self.tokenizer(
            prompts, return_tensors="pt", padding=True
        ).to(self.scalar_head_lm.device)
        del data["token_type_ids"]
        outputs = self.scalar_head_lm(**data)
        print(outputs)


class SFTOracleApproximator:
    """Holds a causal LM and tokenizer for SFT-based oracle approximation."""

    def __init__(self, config, tokenizer, device):
        causal_lm = AutoModelForCausalLM.from_pretrained(
            config._name_or_path,
            config=config
        )
        # NOTE(review): attribute name "ml" is unusual — possibly a typo for
        # "lm"; kept as-is since renaming would change the public interface.
        self.ml = causal_lm.to(device)
        self.tokenizer = tokenizer


if __name__ == "__main__":
    # Smoke test: build a ScalarOracleApproximator from a local checkpoint
    # and run it on a handful of toy alkane molecules.
    checkpoint_path = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/26d322857a184fcbafda5d4a/checkpoint-118784"
    config = AutoConfig.from_pretrained(checkpoint_path)
    tokenizer = AutoTokenizer.from_pretrained("chemlactica/tokenizer/ChemLacticaTokenizer66", padding_side="left")
    scalar_oracle_approx = ScalarOracleApproximator(config, tokenizer)

    mol_entries = [MoleculeEntry("CCC" + "C" * i) for i in range(10)]
    scalar_oracle_approx(mol_entries)
50 changes: 50 additions & 0 deletions chemlactica/mol_opt/tunning.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this piece is redundant as we already have code for supervised fine tuning. Is there a way we can use the existing code base?
also it looks like this code is using the old config format, let's update this if possible as this won't match with our existing code base anymore.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup
from torch.optim.lr_scheduler import ConstantLR
import torch


def supervised_fine_tune(
    model, tokenizer,
    train_dataset, config
):
    """Fine-tune `model` on `train_dataset` with TRL's SFTTrainer.

    config keys used: checkpoints_dir, train_batch_size, global_gradient_norm,
    num_train_epochs, dataloader_num_workers, max_learning_rate, adam_beta1,
    adam_beta2, weight_decay, warmup_steps, response_template,
    formatting_func, packing, max_seq_length.
    """
    model.train()
    training_args = TrainingArguments(
        output_dir=config["checkpoints_dir"],
        per_device_train_batch_size=config["train_batch_size"],
        max_grad_norm=config["global_gradient_norm"],
        num_train_epochs=config["num_train_epochs"],
        evaluation_strategy="no",  # no eval set is provided here
        dataloader_drop_last=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=config["dataloader_num_workers"],
        logging_steps=1
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["max_learning_rate"],
        betas=[config["adam_beta1"], config["adam_beta2"]],
        weight_decay=config["weight_decay"],
    )
    # NOTE(review): lr_end = 0.999 * max_lr decays the LR by only 0.1% over
    # training, i.e. effectively no annealing. Confirm this is intentional
    # and not a typo for 0.001 * max_lr.
    lr_scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config["warmup_steps"],
        num_training_steps=config["num_train_epochs"] * (len(train_dataset) // config["train_batch_size"] + 1),
        lr_end=0.999 * config["max_learning_rate"],
        power=1.0,
    )
    # Mask the prompt so loss is computed only on the completion after
    # response_template.
    collator = DataCollatorForCompletionOnlyLM(
        config["response_template"], tokenizer=tokenizer
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        formatting_func=config["formatting_func"],
        args=training_args,
        packing=config["packing"],
        tokenizer=tokenizer,
        max_seq_length=config["max_seq_length"],
        data_collator=collator,
        # Trainer's documented signature takes a (optimizer, scheduler)
        # tuple, not a list.
        optimizers=(optimizer, lr_scheduler)
    )
    trainer.train()
Loading
Loading