-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mol opt #24
Mol opt #24
Changes from 3 commits
9587e84
2a9cc8c
ab876f9
56e9232
de7a695
40f7dde
7559e5b
a023281
8ae3162
f2ee468
44cfb4b
898f5c1
a68b4f2
f5e1ca2
52b3a64
ca32500
cff4db7
efd8106
cade253
70642da
a805c1b
20d3fe0
bfa32d7
d51cbfb
1502e65
188f384
d95e4d4
800278b
1306842
d9f85d1
b98a2cb
32cc5aa
0ac3a7b
d8ac65f
05af82d
16a00ca
832906c
1f8142f
91d9390
2cf0f8b
9c96919
5d7a258
be4096d
a84c79c
c393b1a
5537720
10feca8
52692ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import torch | ||
from datasets import Dataset | ||
import multiprocessing | ||
import gc | ||
import re | ||
import numpy as np | ||
from collections import namedtuple | ||
from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number | ||
from chemlactica.mol_opt.tunning import supervised_fine_tune | ||
|
||
|
||
def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range):
    """Build ``num_prompts`` generation prompts conditioned on pool molecules.

    Every prompt has the form
    ``</s>[SIMILAR]<smiles> <similarity>[/SIMILAR]...[START_SMILES]``.

    Args:
        num_prompts: number of prompts to create.
        molecule_pool: object exposing ``random_subset(k)``; the returned
            entries must have a ``smiles`` attribute.
        max_similars_in_prompt: value forwarded to ``random_subset``.
        sim_range: indexable pair whose first two items are forwarded to
            ``generate_random_number`` to draw each similarity value.

    Returns:
        A list with ``num_prompts`` prompt strings.
    """

    def _single_prompt():
        # A fresh random subset is drawn for every prompt.
        chosen = molecule_pool.random_subset(max_similars_in_prompt)
        similar_tags = "".join(
            f"[SIMILAR]{entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]"
            for entry in chosen
        )
        return "</s>" + similar_tags + "[START_SMILES]"

    return [_single_prompt() for _ in range(num_prompts)]
|
||
|
||
def create_molecule_entry(output_text):
    """Turn one decoded generation into a ``MoleculeEntry``.

    The generated SMILES is the text between ``[START_SMILES]`` and the
    first ``[END_SMILES]`` that follows it.

    Args:
        output_text: decoded model output (prompt plus continuation).

    Returns:
        A ``MoleculeEntry`` for the generated SMILES, or ``None`` when
        either tag is missing, when the generated SMILES is identical to one
        of the ``[SIMILAR]`` molecules in the prompt, or when constructing
        the ``MoleculeEntry`` raises.
    """
    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
    start_ind = output_text.find(start_smiles_tag)
    if start_ind == -1:
        return None
    smiles_start = start_ind + len(start_smiles_tag)
    # Only an end tag located after the start tag can close the SMILES.
    end_ind = output_text.find(end_smiles_tag, smiles_start)
    if end_ind == -1:
        return None

    generated_smiles = output_text[smiles_start:end_ind]
    # Each chunk after the first begins with "<smiles> <similarity>". The
    # last chunk also contains the generated part, but its leading token is
    # still the final similar molecule, so it has to be checked as well
    # (slicing with [1:-1] silently skipped it).
    for similar in output_text.split("[SIMILAR]")[1:]:
        similar_smiles = similar.split(" ")[0]
        if generated_smiles == similar_smiles:
            return None
    try:
        return MoleculeEntry(
            smiles=generated_smiles,
        )
    except Exception:
        # Best effort: a SMILES that cannot be turned into an entry is
        # dropped. Unlike a bare ``except`` this lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
|
||
|
||
def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs):
    """Prompt the model for a property of ``smiles`` and parse its answer.

    The prompt is ``</s>[START_SMILES]<smiles>[END_SMILES][<property_tag>]``.

    Args:
        model: object with a ``device`` attribute and a ``generate`` method.
        tokenizer: callable tokenizer that also provides ``batch_decode``.
        smiles: SMILES string placed in the prompt.
        property_tag: tag name, used as ``[tag]`` / ``[/tag]``.
        prop_pred_kwargs: extra keyword arguments for ``model.generate``.

    Returns:
        One item per decoded output: the text between the first
        ``[tag]`` and the first ``[/tag]``, or ``None`` if either is absent.
    """
    open_tag = f"[{property_tag}]"
    close_tag = f"[/{property_tag}]"

    def _extract_value(text):
        begin = text.find(open_tag)
        finish = text.find(close_tag)
        if begin == -1 or finish == -1:
            return None
        return text[begin + len(open_tag):finish]

    prompts = [f"</s>[START_SMILES]{smiles}[END_SMILES][{property_tag}]"]
    encoded = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    # token_type_ids is removed before the inputs are passed to generate().
    del encoded["token_type_ids"]
    generated = model.generate(**encoded, **prop_pred_kwargs)
    return [_extract_value(text) for text in tokenizer.batch_decode(generated)]
|
||
|
||
def _generate_output_texts(model, tokenizer, prompts, generation_config, batch_size=100):
    """Generate from ``prompts`` in batches and return the decoded texts in prompt order."""
    output_texts = []
    for start in range(0, len(prompts), batch_size):
        prompt_batch = prompts[start:start + batch_size]
        data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
        # token_type_ids is removed before the inputs are passed to generate().
        del data["token_type_ids"]
        output = model.generate(
            **data,
            **generation_config
        )
        # Free Python garbage and cached CUDA blocks between batches.
        gc.collect()
        torch.cuda.empty_cache()
        output_texts.extend(tokenizer.batch_decode(output))
    return output_texts


def _run_optimization(model, tokenizer, oracle, config, file):
    """Run the generate / score / (optionally fine-tune) loop, logging to ``file``."""
    print("config", config)
    # print("molecule generation arguments", config["generation_config"])
    molecule_pool = MoleculePool(config["molecule_pool_size"])
    # Only used by the "rej-sample" strategy.
    training_entries = []

    while True:
        model.eval()
        current_entries = []
        while len(current_entries) < config["num_gens_per_iter"]:
            prompts = create_optimization_prompts(
                config["num_gens_per_iter"], molecule_pool,
                max_similars_in_prompt=config["max_similars_in_prompt"],
                sim_range=config["sim_range"]
            )
            output_texts = _generate_output_texts(
                model, tokenizer, prompts, config["generation_config"]
            )

            with multiprocessing.Pool(processes=config["num_processes"]) as pol:
                for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)):
                    if not entry:
                        continue
                    entry.score = oracle(entry.smiles)
                    # NOTE(review): assumes generate() returns exactly one
                    # sequence per prompt so output i matches prompt i - confirm.
                    entry.additional_properties["prompt"] = prompts[i]
                    current_entries.append(entry)
                    file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n")
                    if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]:
                        break

            if oracle.finish:
                break

        # np.unique de-duplicates and sorts the entries using MoleculeEntry's
        # own comparison; the result is then reversed.
        current_entries = list(np.unique(current_entries))[::-1]
        if oracle.finish:
            break

        if config["strategy"] == "rej-sample":
            rej_config = config["rej_sample_config"]
            top_k = int(len(current_entries) * rej_config["rej_perc"])
            training_entries.extend(current_entries[:top_k])
            training_entries = list(np.unique(training_entries))[::-1]
            if len(training_entries) >= rej_config["num_samples_per_round"]:
                print(f"Num of train examples {len(training_entries)}.")
                file.write("Training entries\n")
                for i, mol in enumerate(training_entries):
                    file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n")
                train_dataset = Dataset.from_dict({
                    "sample": [
                        f"{entry.additional_properties['prompt']}{entry.smiles}[END_SMILES]</s>"
                        for entry in training_entries
                    ]
                })
                config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"]
                supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"])
                training_entries = []
                gc.collect()
                torch.cuda.empty_cache()

        molecule_pool.add(current_entries)
        file.write("Molecule pool\n")
        for i, mol in enumerate(molecule_pool.molecule_entries):
            file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n")


def optimize(
    model, tokenizer,
    oracle, config
):
    """Iteratively generate molecules, score them and grow the molecule pool.

    Each round builds prompts from the current pool, generates candidates,
    scores every parsed candidate with ``oracle`` and adds the de-duplicated
    results to the pool. With ``config["strategy"] == "rej-sample"`` a top
    fraction of each round is accumulated and, once enough samples exist,
    the model is fine-tuned on them via ``supervised_fine_tune``. The loop
    ends when ``oracle.finish`` becomes true.

    Args:
        model: language model providing ``eval``, ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        oracle: callable mapping a SMILES string to a score; its ``finish``
            attribute signals that optimization should stop.
        config: dict read for the keys ``log_dir`` (path of the log file to
            write), ``molecule_pool_size``, ``strategy``,
            ``num_gens_per_iter``, ``max_similars_in_prompt``, ``sim_range``,
            ``generation_config``, ``num_processes`` and, for rejection
            sampling, ``rej_sample_config`` (``rej_perc``,
            ``num_samples_per_round`` plus the fine-tuning settings; its
            ``formatting_func`` entry is overwritten here).

    Returns:
        None. Results are written to the log file.
    """
    # The context manager closes (and flushes) the log on normal exit and on
    # errors; previously the handle was never closed.
    with open(config["log_dir"], "w") as file:
        _run_optimization(model, tokenizer, oracle, config, file)
|
||
|
||
# def optimize_reinvent( | ||
# model, prior_model, | ||
# tokenizer, oracle, | ||
# config | ||
# ): | ||
# print("molecule pool size", config["molecule_pool_size"]) | ||
# print("molecule generation arguments", config["mol_gen_kwargs"]) | ||
# molecule_pool = MoleculePool(config["molecule_pool_size"]) | ||
|
||
# num_iter = 1 | ||
# while True: | ||
# generated_entries = [] | ||
# while len(generated_entries) < config["num_gens_per_iter"]: | ||
# prompts = create_optimization_prompts( | ||
# config["num_gens_per_iter"], molecule_pool, | ||
# max_similars_in_prompt=config["max_similars_in_prompt"], | ||
# sim_range=config["sim_range"] | ||
# ) | ||
# output_texts = [] | ||
# generation_batch_size = 200 | ||
# for i in range(0, len(prompts), generation_batch_size): | ||
# prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] | ||
# data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) | ||
# del data["token_type_ids"] | ||
# output = model.generate( | ||
# **data, | ||
# **config["mol_gen_kwargs"] | ||
# ) | ||
# gc.collect() | ||
# torch.cuda.empty_cache() | ||
# output_texts.extend(tokenizer.batch_decode(output)) | ||
|
||
# with multiprocessing.Pool(processes=config["num_processes"]) as pol: | ||
# for entry in pol.map(create_molecule_entry, output_texts) if entry]: | ||
# entry.score = oracle(entry.smiles) | ||
# generated_entries.append(entry) | ||
# if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]: | ||
# break | ||
|
||
# if oracle.finish: | ||
# break | ||
|
||
# num_iter += 1 | ||
# if oracle.finish: | ||
# break | ||
# molecule_pool.add(generated_entries) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from typing import List | ||
import time | ||
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | ||
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoModel, AutoConfig, AutoTokenizer | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from chemlactica.mol_opt.utils import MoleculeEntry | ||
from sklearn.linear_model import Ridge | ||
|
||
|
||
def find_second_eos_token_indices(sequences, eos_token_id):
    """Locate ``eos_token_id`` in every column of ``sequences`` except the first.

    Returns:
        The tuple of index tensors produced by ``torch.where``. Column
        indices are relative to ``sequences[:, 1:]``, i.e. they are one less
        than the position in the full ``sequences`` tensor.
    """
    without_first_column = sequences[:, 1:]
    eos_mask = without_first_column == eos_token_id
    return torch.where(eos_mask)
|
||
|
||
def init_linear_layer(layer, emb_length):
    """Re-initialise ``layer`` in place and return it.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``1 / sqrt(emb_length + 1)``; the bias is set to zero.
    """
    weight_std = 1 / np.sqrt(emb_length + 1)
    torch.nn.init.normal_(layer.weight, mean=0.0, std=weight_std)
    torch.nn.init.constant_(layer.bias, val=0.0)
    return layer
|
||
|
||
class ScalarHeadLM(PreTrainedModel):
    """Pretrained LM backbone followed by a linear head with a single output."""

    def __init__(self, config):
        """Load the backbone from ``config._name_or_path`` and create the head.

        Args:
            config: model config; ``_name_or_path`` selects the checkpoint
                and ``hidden_size`` sets the head's input width.
        """
        super().__init__(config)
        self.config = config
        self.lm_backbone = AutoModel.from_pretrained(
            config._name_or_path,
            config=config
        )
        self.scalar_head = nn.Linear(config.hidden_size, 1)
        # init_linear_layer requires the embedding length as its second
        # argument; omitting it raised TypeError at construction time.
        init_linear_layer(self.scalar_head, config.hidden_size)

    def forward(self, **kwargs):
        """Run the backbone on ``kwargs`` and apply the head to ``last_hidden_state``."""
        output = self.lm_backbone(**kwargs)
        return self.scalar_head(output.last_hidden_state)
|
||
|
||
class LinearFingerprintModel:
    """Ridge regression from molecule fingerprints to scores.

    Every call to ``fit`` adds the new entries to ``all_entries`` and refits
    the regressor on everything seen so far.
    """

    def __init__(self):
        self.is_fit = False
        self.all_entries = []
        self.linear = Ridge()
        self.emb_length = 2048

    def __call__(self, mol_entries: List[MoleculeEntry]):
        """Predict a score for each entry from its ``fingerprint``."""
        fingerprints = [entry.fingerprint for entry in mol_entries]
        return self.linear.predict(np.array(fingerprints))

    def fit(self, mol_entries: List[MoleculeEntry]):
        """Add ``mol_entries`` to the training set, refit and print the fit time."""
        self.is_fit = True
        started_at = time.time()
        self.all_entries.extend(mol_entries)
        features = np.array([entry.fingerprint for entry in self.all_entries])
        targets = np.array([entry.score for entry in self.all_entries])
        self.linear.fit(features, targets)
        print(f"Fit time {time.time() - started_at:.4f}s")
|
||
|
||
class ScalarOracleApproximator:
    """Wraps a ``ScalarHeadLM`` and a tokenizer to run it on molecule entries."""

    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.scalar_head_lm = ScalarHeadLM(config)

    def __call__(self, mol_entries):
        """Run the scalar-head model on the entries' SMILES and print the output.

        Each entry becomes the prompt
        ``</s>[START_SMILES]<smiles>[END_SMILES]</s>``. Nothing is returned.
        """
        prompts = []
        for entry in mol_entries:
            prompts.append(f"</s>[START_SMILES]{entry.smiles}[END_SMILES]</s>")
        target_device = self.scalar_head_lm.device
        encoded = self.tokenizer(prompts, return_tensors="pt", padding=True).to(target_device)
        # token_type_ids is removed before the inputs are passed to the model.
        del encoded["token_type_ids"]
        outputs = self.scalar_head_lm(**encoded)
        print(outputs)
|
||
|
||
class SFTOracleApproximator:
    """Holds a causal LM loaded from ``config._name_or_path`` and its tokenizer."""

    def __init__(self, config, tokenizer, device):
        """Load the causal LM, move it to ``device`` and keep the tokenizer."""
        causal_lm = AutoModelForCausalLM.from_pretrained(
            config._name_or_path,
            config=config
        )
        self.ml = causal_lm.to(device)
        self.tokenizer = tokenizer
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: build the scalar-head approximator from a local
    # checkpoint and run it on ten simple alkane SMILES strings.
    checkpoint_path = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/26d322857a184fcbafda5d4a/checkpoint-118784"
    tokenizer_path = "chemlactica/tokenizer/ChemLacticaTokenizer66"

    config = AutoConfig.from_pretrained(checkpoint_path)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left")
    scalar_oracle_approx = ScalarOracleApproximator(config, tokenizer)

    mol_entries = []
    for i in range(10):
        mol_entries.append(MoleculeEntry("CCC" + i * "C"))
    scalar_oracle_approx(mol_entries)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This piece seems redundant, since we already have code for supervised fine-tuning. Is there a way to reuse the existing codebase? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | ||
from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup | ||
from torch.optim.lr_scheduler import ConstantLR | ||
import torch | ||
|
||
|
||
def supervised_fine_tune(
    model, tokenizer,
    train_dataset, config
):
    """Fine-tune ``model`` on ``train_dataset`` with TRL's ``SFTTrainer``.

    The model is put in training mode and trained with AdamW under a
    polynomial-decay schedule (``power=1.0``) that ends at 99.9% of the
    peak learning rate.

    Args:
        model: model to fine-tune; updated in place by the trainer.
        tokenizer: tokenizer used by the collator and the trainer.
        train_dataset: dataset whose length determines the scheduler's
            step count.
        config: dict read for the keys ``checkpoints_dir``,
            ``train_batch_size``, ``global_gradient_norm``,
            ``num_train_epochs``, ``dataloader_num_workers``,
            ``max_learning_rate``, ``adam_beta1``, ``adam_beta2``,
            ``weight_decay``, ``warmup_steps``, ``response_template``,
            ``formatting_func``, ``packing`` and ``max_seq_length``.
    """
    model.train()

    training_args = TrainingArguments(
        output_dir=config["checkpoints_dir"],
        per_device_train_batch_size=config["train_batch_size"],
        max_grad_norm=config["global_gradient_norm"],
        num_train_epochs=config["num_train_epochs"],
        evaluation_strategy="no",
        dataloader_drop_last=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=config["dataloader_num_workers"],
        logging_steps=1
    )

    peak_lr = config["max_learning_rate"]
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=[config["adam_beta1"], config["adam_beta2"]],
        weight_decay=config["weight_decay"],
    )

    # NOTE(review): "// batch + 1" counts one extra step per epoch when the
    # dataset size is an exact multiple of the batch size - confirm intended.
    steps_per_epoch = len(train_dataset) // config["train_batch_size"] + 1
    lr_scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config["warmup_steps"],
        num_training_steps=config["num_train_epochs"] * steps_per_epoch,
        lr_end=0.999 * config["max_learning_rate"],
        power=1.0,
    )

    collator = DataCollatorForCompletionOnlyLM(
        config["response_template"], tokenizer=tokenizer
    )
    sft_trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        formatting_func=config["formatting_func"],
        args=training_args,
        packing=config["packing"],
        tokenizer=tokenizer,
        max_seq_length=config["max_seq_length"],
        data_collator=collator,
        optimizers=[optimizer, lr_scheduler]
    )
    sft_trainer.train()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not clear where this optimize function is called, or what the input model and config are. Please push the code that runs this module.