From 9587e847c317aed0361623e0ee19f6c6d912b756 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 19 Apr 2024 14:20:51 +0400 Subject: [PATCH 01/45] add mol opt --- chemlactica/mol_opt/__init__.py | 0 chemlactica/mol_opt/optimization.py | 171 +++++++++++++++++++++++ chemlactica/mol_opt/oracle_estimators.py | 96 +++++++++++++ chemlactica/mol_opt/utils.py | 98 +++++++++++++ 4 files changed, 365 insertions(+) create mode 100644 chemlactica/mol_opt/__init__.py create mode 100644 chemlactica/mol_opt/optimization.py create mode 100644 chemlactica/mol_opt/oracle_estimators.py create mode 100644 chemlactica/mol_opt/utils.py diff --git a/chemlactica/mol_opt/__init__.py b/chemlactica/mol_opt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py new file mode 100644 index 0000000..9c8f0b3 --- /dev/null +++ b/chemlactica/mol_opt/optimization.py @@ -0,0 +1,171 @@ +import torch +from transformers import OPTForCausalLM, AutoTokenizer +import multiprocessing +import gc +from collections import namedtuple +from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number + + +def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range): + prompts = [] + for i in range(num_prompts): + similars_in_prompt = molecule_pool.random_subset(max_similars_in_prompt) + prompt = "" + for mol in similars_in_prompt: + prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + prompt += "[START_SMILES]" + prompts.append(prompt) + return prompts + + +def create_molecule_entry(output_text): + start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]" + start_ind = output_text.find(start_smiles_tag) + end_ind = output_text.find(end_smiles_tag) + if start_ind == -1 or end_ind == -1: + return None + + generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] + try: + return MoleculeEntry( + smiles=generated_smiles, + ) + except: + return None + + +def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): + property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" + prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] + data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + del data["token_type_ids"] + outputs = model.generate( + **data, + **prop_pred_kwargs + ) + predicted_property_values = [] + output_texts = tokenizer.batch_decode(outputs) + for output_text in output_texts: + start_ind = output_text.find(property_start_tag) + end_ind = output_text.find(property_end_tag) + if start_ind != -1 and end_ind != -1: + predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) + else: + predicted_property_values.append(None) + return predicted_property_values + + +def optimize( + model, tokenizer, + oracle, oracle_estimator, + config + ): + print("molecule pool size", config["molecule_pool_size"]) + print("molecule generation arguments", config["mol_gen_kwargs"]) + molecule_pool = MoleculePool(config["molecule_pool_size"]) + + num_iter = 1 + while True: + generated_entries = [] + oracle_estimator_error = 0 + while len(generated_entries) < config["num_gens_per_iter"]: + prompts = create_optimization_prompts( + config["num_gens_per_iter"], molecule_pool, + max_similars_in_prompt=config["max_similars_in_prompt"], + sim_range=config["sim_range"] + ) + output_texts = [] + generation_batch_size = 200 + 
for i in range(0, len(prompts), generation_batch_size): + prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] + data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) + del data["token_type_ids"] + output = model.generate( + **data, + **config["mol_gen_kwargs"] + ) + gc.collect() + torch.cuda.empty_cache() + output_texts.extend(tokenizer.batch_decode(output)) + + candidate_entries = [] + with multiprocessing.Pool(processes=config["num_processes"]) as pol: + candidate_entries.extend([entry for entry in pol.map(create_molecule_entry, output_texts) if entry]) + + # take top-k using oracle estimator + top_k = len(candidate_entries) + if num_iter != 1 and oracle_estimator: + score_estimates = oracle_estimator(candidate_entries) + for score_est, entry in zip(score_estimates, candidate_entries): + entry.score_estimate = score_est + candidate_entries.sort(key=lambda x: x.score_estimate, reverse=True) + top_k //= 4 + + for entry in candidate_entries[:top_k]: + entry.score = oracle(entry.smiles) + generated_entries.append(entry) + if oracle_estimator and entry.score_estimate: + oracle_estimator_error += abs(entry.score - entry.score_estimate) + if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]: + break + + if oracle.finish: + break + + num_iter += 1 + if oracle_estimator: + oracle_estimator_error = oracle_estimator_error / len(generated_entries) + print(f"Oracle estimate mean absolute error: {oracle_estimator_error:.4f}") + if not oracle_estimator.is_fit or oracle_estimator_error > 0.1: + oracle_estimator.fit(generated_entries) + if oracle.finish: + break + molecule_pool.add(generated_entries) + + +# def optimize_reinvent( +# model, prior_model, +# tokenizer, oracle, +# config +# ): +# print("molecule pool size", config["molecule_pool_size"]) +# print("molecule generation arguments", config["mol_gen_kwargs"]) +# molecule_pool = MoleculePool(config["molecule_pool_size"]) + +# num_iter = 1 +# while True: +# generated_entries = [] +# while len(generated_entries) < config["num_gens_per_iter"]: +# prompts = create_optimization_prompts( +# config["num_gens_per_iter"], molecule_pool, +# max_similars_in_prompt=config["max_similars_in_prompt"], +# sim_range=config["sim_range"] +# ) +# output_texts = [] +# generation_batch_size = 200 +# for i in range(0, len(prompts), generation_batch_size): +# prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] +# data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) +# del data["token_type_ids"] +# output = model.generate( +# **data, +# **config["mol_gen_kwargs"] +# ) +# gc.collect() +# torch.cuda.empty_cache() +# output_texts.extend(tokenizer.batch_decode(output)) + +# with multiprocessing.Pool(processes=config["num_processes"]) as pol: +# for entry in pol.map(create_molecule_entry, output_texts) if entry]: +# entry.score = oracle(entry.smiles) +# generated_entries.append(entry) +# if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]: +# break + +# if oracle.finish: +# break + +# num_iter += 1 +# if oracle.finish: +# break +# molecule_pool.add(generated_entries) \ No newline at end of file diff --git a/chemlactica/mol_opt/oracle_estimators.py b/chemlactica/mol_opt/oracle_estimators.py new file mode 100644 index 0000000..3d64477 --- /dev/null +++ b/chemlactica/mol_opt/oracle_estimators.py @@ -0,0 +1,96 @@ +from typing import List +import time +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM +from transformers import 
AutoModelForCausalLM, PreTrainedModel, AutoModel, AutoConfig, AutoTokenizer +import torch +import torch.nn as nn +import numpy as np +from chemlactica.mol_opt.utils import MoleculeEntry +from sklearn.linear_model import Ridge + + +def find_second_eos_token_indices(sequences, eos_token_id): + return torch.where(sequences[:, 1:] == eos_token_id) + + +def init_linear_layer(layer, emb_length): + torch.nn.init.normal_( + layer.weight, + mean=0.0, std=1 / np.sqrt(emb_length + 1) + ) + torch.nn.init.constant_(layer.bias, val=0.0) + return layer + + +class ScalarHeadLM(PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.config = config + self.lm_backbone = AutoModel.from_pretrained( + config._name_or_path, + config=config + ) + self.scalar_head = nn.Linear(config.hidden_size, 1) + init_linear_layer(self.scalar_head) + + def forward(self, **kwargs): + output = self.lm_backbone(**kwargs) + return self.scalar_head(output.last_hidden_state) + + +class LinearFingerprintModel: + + def __init__(self): + self.emb_length = 2048 + self.linear = Ridge() + self.all_entries = [] + self.is_fit = False + + def __call__(self, mol_entries: List[MoleculeEntry]): + mol_embs = np.array([entry.fingerprint for entry in mol_entries]) + return self.linear.predict(mol_embs) + + def fit(self, mol_entries: List[MoleculeEntry]): + self.is_fit = True + start_time = time.time() + self.all_entries.extend(mol_entries) + mol_embs = np.array([entry.fingerprint for entry in self.all_entries]) + scores = np.array([entry.score for entry in self.all_entries]) + self.linear.fit(mol_embs, scores) + print(f"Fit time {time.time() - start_time:.4f}s") + + +class ScalarOracleApproximator: + + def __init__(self, config, tokenizer): + self.scalar_head_lm = ScalarHeadLM(config) + self.tokenizer = tokenizer + + def __call__(self, mol_entries): + prompts = [f"[START_SMILES]{e.smiles}[END_SMILES]" for e in mol_entries] + data = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.scalar_head_lm.device) + del data["token_type_ids"] + outputs = self.scalar_head_lm( + **data + ) + print(outputs) + + +class SFTOracleApproximator: + + def __init__(self, config, tokenizer, device): + self.ml = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + config=config + ).to(device) + self.tokenizer = tokenizer + + +if __name__ == "__main__": + config = AutoConfig.from_pretrained("/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/26d322857a184fcbafda5d4a/checkpoint-118784") + tokenizer = AutoTokenizer.from_pretrained("chemlactica/tokenizer/ChemLacticaTokenizer66", padding_side="left") + scalar_oracle_approx = ScalarOracleApproximator(config, tokenizer) + + mol_entries = [MoleculeEntry("CCC" + i * "C") for i in range(10)] + scalar_oracle_approx(mol_entries) \ No newline at end of file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py new file mode 100644 index 0000000..c07b09f --- /dev/null +++ b/chemlactica/mol_opt/utils.py @@ -0,0 +1,98 @@ +from typing import List +import random +from pathlib import Path +import numpy as np +import torch +from rdkit import Chem, DataStructs, RDLogger +from rdkit.Chem import AllChem, MACCSkeys +from rdkit.Chem.QED import qed + +# Disable RDKit logs +RDLogger.DisableLog('rdApp.*') + + +def set_seed(seed_value): + random.seed(seed_value) + # Set seed for NumPy + np.random.seed(seed_value) + # Set seed for PyTorch + torch.manual_seed(seed_value) + + +def get_short_name_for_ckpt_path(chpt_path: str, hash_len: int=6): + get_short_name_for_ckpt_path 
= Path(chpt_path) + return get_short_name_for_ckpt_path.parent.name[:hash_len] + '-' + get_short_name_for_ckpt_path.name.split("-")[-1] + + +def get_morgan_fingerprint(mol): + return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) + + +def get_maccs_fingerprint(mol): + return MACCSkeys.GenMACCSKeys(mol) + + +def tanimoto_dist_func(mol1, mol2, fingerprint: str="morgan"): + return DataStructs.TanimotoSimilarity( + get_morgan_fingerprint(mol1) if fingerprint == 'morgan' else get_maccs_fingerprint(mol1), + get_morgan_fingerprint(mol2) if fingerprint == 'morgan' else get_maccs_fingerprint(mol2), + ) + + +def generate_random_number(lower, upper): + return lower + random.random() * (upper - lower) + + +def canonicalize(smiles): + smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True) + return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), kekuleSmiles=True) + + +class MoleculeEntry: + + def __init__(self, smiles, score=None, score_estimate=None, **kwargs): + self.smiles = canonicalize(smiles) + self.mol = Chem.MolFromSmiles(smiles) + self.inchi = Chem.MolToInchi(self.mol) + self.fingerprint = get_morgan_fingerprint(self.mol) + self.score = score + self.score_estimate = score_estimate + self.additional_properties = kwargs + + def __eq__(self, other): + return self.inchi == other.inchi + + def __lt__(self, other): + return self.score < other.score + + def __str__(self): + return f"smiles: {self.smiles}, " \ + f"score: {round(self.score, 4) if self.score else 'none'}, " \ + f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate else 'none'}" + + def __repr__(self): + return str(self) + + +class MoleculePool: + + def __init__(self, size): + self.size = size + self.molecule_entries: List[MoleculeEntry] = [] + + def add(self, entries: List[MoleculeEntry]): + assert type(entries) == list + self.molecule_entries.extend(entries) + self.molecule_entries.sort(reverse=True) + + # remove doublicates + new_molecule_list = [] + for mol in self.molecule_entries: + if len(new_molecule_list) == 0 or new_molecule_list[-1] != mol: + new_molecule_list.append(mol) + + self.molecule_entries = new_molecule_list[:min(len(new_molecule_list), self.size)] + + def random_subset(self, subset_size): + rand_inds = np.random.permutation(min(len(self.molecule_entries), subset_size)) + return [self.molecule_entries[i] for i in rand_inds] \ No newline at end of file From 2a9cc8c4d47ef476d313c99a6342d7cc240cde14 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 22 Apr 2024 08:32:43 +0400 Subject: [PATCH 02/45] mol opt+rej sample --- chemlactica/mol_opt/optimization.py | 90 ++++++++++++++++++----------- chemlactica/mol_opt/tunning.py | 50 ++++++++++++++++ chemlactica/mol_opt/utils.py | 32 +++++----- 3 files changed, 124 insertions(+), 48 deletions(-) create mode 100644 chemlactica/mol_opt/tunning.py diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 9c8f0b3..7e021f1 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -1,9 +1,12 @@ import torch -from transformers import OPTForCausalLM, AutoTokenizer +from datasets import Dataset import multiprocessing import gc +import re +import numpy as np from collections import namedtuple from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number +from chemlactica.mol_opt.tunning import supervised_fine_tune def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range): @@ -26,6 +29,10 @@ def 
create_molecule_entry(output_text): return None generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] + for similar in output_text.split("[SIMILAR]")[1:-1]: + similar_smiles = similar.split(" ")[0] + if generated_smiles == similar_smiles: + return None try: return MoleculeEntry( smiles=generated_smiles, @@ -57,70 +64,83 @@ def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_ def optimize( model, tokenizer, - oracle, oracle_estimator, - config + oracle, config ): - print("molecule pool size", config["molecule_pool_size"]) - print("molecule generation arguments", config["mol_gen_kwargs"]) + file = open(config["log_dir"], "w") + print("config", config) + # print("molecule generation arguments", config["generation_config"]) molecule_pool = MoleculePool(config["molecule_pool_size"]) + if config["strategy"] == "rej_sample": + training_entries = [] + num_iter = 1 while True: - generated_entries = [] - oracle_estimator_error = 0 - while len(generated_entries) < config["num_gens_per_iter"]: + model.eval() + current_entries = [] + while len(current_entries) < config["num_gens_per_iter"]: prompts = create_optimization_prompts( config["num_gens_per_iter"], molecule_pool, max_similars_in_prompt=config["max_similars_in_prompt"], sim_range=config["sim_range"] ) output_texts = [] - generation_batch_size = 200 + generation_batch_size = 100 for i in range(0, len(prompts), generation_batch_size): prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) del data["token_type_ids"] output = model.generate( **data, - **config["mol_gen_kwargs"] + **config["generation_config"] ) gc.collect() torch.cuda.empty_cache() output_texts.extend(tokenizer.batch_decode(output)) - candidate_entries = [] with multiprocessing.Pool(processes=config["num_processes"]) as pol: - candidate_entries.extend([entry for entry in pol.map(create_molecule_entry, output_texts) if entry]) - - # take top-k using oracle estimator - top_k = len(candidate_entries) - if num_iter != 1 and oracle_estimator: - score_estimates = oracle_estimator(candidate_entries) - for score_est, entry in zip(score_estimates, candidate_entries): - entry.score_estimate = score_est - candidate_entries.sort(key=lambda x: x.score_estimate, reverse=True) - top_k //= 4 - - for entry in candidate_entries[:top_k]: - entry.score = oracle(entry.smiles) - generated_entries.append(entry) - if oracle_estimator and entry.score_estimate: - oracle_estimator_error += abs(entry.score - entry.score_estimate) - if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]: - break + for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): + if entry: + entry.score = oracle(entry.smiles) + entry.additional_properties["prompt"] = prompts[i] + current_entries.append(entry) + file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") + if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: + break if oracle.finish: break + current_entries = list(np.unique(current_entries))[::-1] num_iter += 1 - if oracle_estimator: - oracle_estimator_error = oracle_estimator_error / len(generated_entries) - print(f"Oracle estimate mean absolute error: {oracle_estimator_error:.4f}") - if not oracle_estimator.is_fit or oracle_estimator_error > 0.1: - oracle_estimator.fit(generated_entries) if oracle.finish: break - molecule_pool.add(generated_entries) + + if config["strategy"] == "rej_sample": + top_k = 
int(len(current_entries) * config["rej_sample_config"]["rej_perc"]) + training_entries.extend(current_entries[:top_k]) + training_entries = list(np.unique(training_entries))[::-1] + if len(training_entries) >= config["rej_sample_config"]["num_samples_per_round"]: + print(f"Num of train examples {len(training_entries)}.") + file.write("Training entries") + for i, mol in enumerate(training_entries): + file.write(f"\t{i} smiles {mol.smiles}, score {mol.score:.4f}\n") + train_dataset = Dataset.from_dict({ + "sample": [ + f"{entry.additional_properties['prompt']}{entry.smiles}[END_SMILES]" + for entry in training_entries + ] + }) + config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] + supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) + training_entries = [] + gc.collect() + torch.cuda.empty_cache() + + molecule_pool.add(current_entries) + file.write("Molecule pool\n") + for i, mol in enumerate(molecule_pool.molecule_entries): + file.write(f"\t{i} smiles {mol.smiles}, score {mol.score:.4f}\n") # def optimize_reinvent( diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py new file mode 100644 index 0000000..4eaaa89 --- /dev/null +++ b/chemlactica/mol_opt/tunning.py @@ -0,0 +1,50 @@ +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM +from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup +from torch.optim.lr_scheduler import ConstantLR +import torch + + +def supervised_fine_tune( + model, tokenizer, + train_dataset, config + ): + model.train() + training_args = TrainingArguments( + output_dir=config["checkpoints_dir"], + per_device_train_batch_size=config["train_batch_size"], + max_grad_norm=config["global_gradient_norm"], + num_train_epochs=config["num_train_epochs"], + evaluation_strategy="no", + dataloader_drop_last=False, + dataloader_pin_memory=True, + dataloader_num_workers=config["dataloader_num_workers"], + logging_steps=1 + ) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=config["max_learning_rate"], + betas=[config["adam_beta1"], config["adam_beta2"]], + weight_decay=config["weight_decay"], + ) + lr_scheduler = get_polynomial_decay_schedule_with_warmup( + optimizer, + num_warmup_steps=config["warmup_steps"], + num_training_steps=config["num_train_epochs"] * (len(train_dataset) // config["train_batch_size"] + 1), + lr_end=0.999 * config["max_learning_rate"], + power=1.0, + ) + collator = DataCollatorForCompletionOnlyLM( + config["response_template"], tokenizer=tokenizer + ) + trainer = SFTTrainer( + model=model, + train_dataset=train_dataset, + formatting_func=config["formatting_func"], + args=training_args, + packing=config["packing"], + tokenizer=tokenizer, + max_seq_length=config["max_seq_length"], + data_collator=collator, + optimizers=[optimizer, lr_scheduler] + ) + trainer.train() diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index c07b09f..dca6001 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -32,10 +32,10 @@ def get_maccs_fingerprint(mol): return MACCSkeys.GenMACCSKeys(mol) -def tanimoto_dist_func(mol1, mol2, fingerprint: str="morgan"): +def tanimoto_dist_func(fing1, fing2, fingerprint: str="morgan"): return DataStructs.TanimotoSimilarity( - get_morgan_fingerprint(mol1) if fingerprint == 'morgan' else get_maccs_fingerprint(mol1), - get_morgan_fingerprint(mol2) if fingerprint == 'morgan' else get_maccs_fingerprint(mol2), + fing1 if fingerprint == 'morgan' else fing1, + fing2 if fingerprint == 
'morgan' else fing2, ) @@ -44,8 +44,8 @@ def generate_random_number(lower, upper): def canonicalize(smiles): - smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True) - return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), kekuleSmiles=True) + return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True) + # return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), kekuleSmiles=True) class MoleculeEntry: @@ -53,16 +53,17 @@ class MoleculeEntry: def __init__(self, smiles, score=None, score_estimate=None, **kwargs): self.smiles = canonicalize(smiles) self.mol = Chem.MolFromSmiles(smiles) - self.inchi = Chem.MolToInchi(self.mol) self.fingerprint = get_morgan_fingerprint(self.mol) self.score = score self.score_estimate = score_estimate self.additional_properties = kwargs def __eq__(self, other): - return self.inchi == other.inchi + return self.smiles == other.smiles def __lt__(self, other): + if self.score == other.score: + return self.smiles < other.smiles return self.score < other.score def __str__(self): @@ -72,7 +73,7 @@ def __str__(self): def __repr__(self): return str(self) - + class MoleculePool: @@ -86,12 +87,17 @@ def add(self, entries: List[MoleculeEntry]): self.molecule_entries.sort(reverse=True) # remove doublicates - new_molecule_list = [] + new_molecule_entries = [] for mol in self.molecule_entries: - if len(new_molecule_list) == 0 or new_molecule_list[-1] != mol: - new_molecule_list.append(mol) - - self.molecule_entries = new_molecule_list[:min(len(new_molecule_list), self.size)] + insert = True + for m in new_molecule_entries: + if mol == m or tanimoto_dist_func(mol.fingerprint, m.fingerprint) > 1.0: + insert = False + break + if insert: + new_molecule_entries.append(mol) + + self.molecule_entries = new_molecule_entries[:min(len(new_molecule_entries), self.size)] def random_subset(self, subset_size): rand_inds = np.random.permutation(min(len(self.molecule_entries), subset_size)) From ab876f9c2dd4079b04eb54eef5c3fe1e41b7c018 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 22 Apr 2024 16:26:17 +0400 Subject: [PATCH 03/45] minor logging changes --- chemlactica/mol_opt/optimization.py | 10 +++++----- chemlactica/mol_opt/utils.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 7e021f1..78128a0 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -71,7 +71,7 @@ def optimize( # print("molecule generation arguments", config["generation_config"]) molecule_pool = MoleculePool(config["molecule_pool_size"]) - if config["strategy"] == "rej_sample": + if config["strategy"] == "rej-sample": training_entries = [] num_iter = 1 @@ -116,15 +116,15 @@ def optimize( if oracle.finish: break - if config["strategy"] == "rej_sample": + if config["strategy"] == "rej-sample": top_k = int(len(current_entries) * config["rej_sample_config"]["rej_perc"]) training_entries.extend(current_entries[:top_k]) training_entries = list(np.unique(training_entries))[::-1] if len(training_entries) >= config["rej_sample_config"]["num_samples_per_round"]: print(f"Num of train examples {len(training_entries)}.") - file.write("Training entries") + file.write("Training entries\n") for i, mol in enumerate(training_entries): - file.write(f"\t{i} smiles {mol.smiles}, score {mol.score:.4f}\n") + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") train_dataset = Dataset.from_dict({ "sample": [ 
f"{entry.additional_properties['prompt']}{entry.smiles}[END_SMILES]" @@ -140,7 +140,7 @@ def optimize( molecule_pool.add(current_entries) file.write("Molecule pool\n") for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles {mol.smiles}, score {mol.score:.4f}\n") + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") # def optimize_reinvent( diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index dca6001..390dba8 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -68,8 +68,8 @@ def __lt__(self, other): def __str__(self): return f"smiles: {self.smiles}, " \ - f"score: {round(self.score, 4) if self.score else 'none'}, " \ - f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate else 'none'}" + f"score: {round(self.score, 4) if self.score != None else 'none'}, " \ + f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate != None else 'none'}" def __repr__(self): return str(self) From 56e9232573aadbe96e03dbcb61be22828d42ef8e Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 24 Apr 2024 08:47:00 +0400 Subject: [PATCH 04/45] rej-sample-v1 --- chemlactica/mol_opt/optimization.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 78128a0..0100868 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -72,7 +72,7 @@ def optimize( molecule_pool = MoleculePool(config["molecule_pool_size"]) if config["strategy"] == "rej-sample": - training_entries = [] + round_entries = [] num_iter = 1 while True: @@ -117,10 +117,11 @@ def optimize( break if config["strategy"] == "rej-sample": - top_k = int(len(current_entries) * config["rej_sample_config"]["rej_perc"]) - training_entries.extend(current_entries[:top_k]) - training_entries = list(np.unique(training_entries))[::-1] - if len(training_entries) >= config["rej_sample_config"]["num_samples_per_round"]: + round_entries.extend(current_entries) + round_entries = list(np.unique(round_entries))[::-1] + top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) + if len(round_entries[:top_k]) >= config["rej_sample_config"]["num_samples_per_round"]: + training_entries = round_entries[:top_k] print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") for i, mol in enumerate(training_entries): @@ -131,9 +132,10 @@ def optimize( for entry in training_entries ] }) + # train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) - training_entries = [] + round_entries = [] gc.collect() torch.cuda.empty_cache() From de7a69529c2e64b0765cfcde405a2dfb5603b797 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sun, 28 Apr 2024 01:19:10 +0400 Subject: [PATCH 05/45] add molecules pool dump --- chemlactica/mol_opt/optimization.py | 23 +++++++++++++++++++---- chemlactica/mol_opt/tunning.py | 1 + chemlactica/mol_opt/utils.py | 16 +++++++++++++--- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 0100868..75ff9cf 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -2,7 +2,7 @@ from datasets import Dataset import multiprocessing import gc -import re +import math import numpy as np from collections import 
namedtuple from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number @@ -74,6 +74,8 @@ def optimize( if config["strategy"] == "rej-sample": round_entries = [] + max_score = 0 + tol_level = 0 num_iter = 1 while True: model.eval() @@ -85,11 +87,13 @@ def optimize( sim_range=config["sim_range"] ) output_texts = [] - generation_batch_size = 100 + generation_batch_size = 64 for i in range(0, len(prompts), generation_batch_size): prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) del data["token_type_ids"] + for key, value in data.items(): + data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] output = model.generate( **data, **config["generation_config"] @@ -105,6 +109,9 @@ def optimize( entry.additional_properties["prompt"] = prompts[i] current_entries.append(entry) file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") + if entry.score > max_score + 0.01: + max_score = entry.score + tol_level = 0 if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: break @@ -112,11 +119,18 @@ def optimize( break current_entries = list(np.unique(current_entries))[::-1] + tol_level += 1 num_iter += 1 if oracle.finish: break - if config["strategy"] == "rej-sample": + # print("tol_level", tol_level) + if tol_level >= 5: + num_to_dump = len(molecule_pool) // 2 + molecule_pool.random_dump(num_to_dump) + file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") + tol_level = 0 + if config["strategy"] == "rej-sample" and tol_level >= 5: round_entries.extend(current_entries) round_entries = list(np.unique(round_entries))[::-1] top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) @@ -132,13 +146,14 @@ def optimize( for entry in training_entries ] }) - # train_dataset.shuffle(seed=42) + train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) round_entries = [] gc.collect() torch.cuda.empty_cache() + # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) molecule_pool.add(current_entries) file.write("Molecule pool\n") for i, mol in enumerate(molecule_pool.molecule_entries): diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 4eaaa89..41de28c 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -18,6 +18,7 @@ def supervised_fine_tune( dataloader_drop_last=False, dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], + gradient_accumulation_steps=config["gradient_accumulation_steps"], logging_steps=1 ) optimizer = torch.optim.AdamW( diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 390dba8..3416b43 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -81,17 +81,24 @@ def __init__(self, size): self.size = size self.molecule_entries: List[MoleculeEntry] = [] - def add(self, entries: List[MoleculeEntry]): + def random_dump(self, num): + for _ in range(num): + rand_ind = random.randint(0, num - 1) + self.molecule_entries.pop(rand_ind) + print(f"Dump {num} random elements from pool, num pool mols {len(self)}") + + def add(self, entries: List[MoleculeEntry], diversity_score=1.0): assert type(entries) == list self.molecule_entries.extend(entries) 
self.molecule_entries.sort(reverse=True) + # print(f"Updating with div_score {diversity_score:.4f}") # remove doublicates new_molecule_entries = [] for mol in self.molecule_entries: insert = True for m in new_molecule_entries: - if mol == m or tanimoto_dist_func(mol.fingerprint, m.fingerprint) > 1.0: + if mol == m or tanimoto_dist_func(mol.fingerprint, m.fingerprint) > diversity_score: insert = False break if insert: @@ -101,4 +108,7 @@ def add(self, entries: List[MoleculeEntry]): def random_subset(self, subset_size): rand_inds = np.random.permutation(min(len(self.molecule_entries), subset_size)) - return [self.molecule_entries[i] for i in rand_inds] \ No newline at end of file + return [self.molecule_entries[i] for i in rand_inds] + + def __len__(self): + return len(self.molecule_entries) \ No newline at end of file From 40f7dde8cc663beb94055948630d69c82063d53b Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 29 Apr 2024 17:11:01 +0400 Subject: [PATCH 06/45] add pool dump --- chemlactica/mol_opt/optimization.py | 65 +++++------------------------ 1 file changed, 10 insertions(+), 55 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 75ff9cf..2e7e6e6 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -3,19 +3,22 @@ import multiprocessing import gc import math +import tqdm +import random import numpy as np -from collections import namedtuple from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range): +def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range, post_processor=None): prompts = [] for i in range(num_prompts): similars_in_prompt = molecule_pool.random_subset(max_similars_in_prompt) prompt = "" for mol in similars_in_prompt: prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + if post_processor: + prompt = post_processor(prompt) prompt += "[START_SMILES]" prompts.append(prompt) return prompts @@ -84,7 +87,7 @@ def optimize( prompts = create_optimization_prompts( config["num_gens_per_iter"], molecule_pool, max_similars_in_prompt=config["max_similars_in_prompt"], - sim_range=config["sim_range"] + sim_range=config["sim_range"], post_processor=config.get("prompts_post_processor") ) output_texts = [] generation_batch_size = 64 @@ -125,12 +128,12 @@ def optimize( break # print("tol_level", tol_level) - if tol_level >= 5: - num_to_dump = len(molecule_pool) // 2 + if config["strategy"] == "pool-dump" and tol_level >= 5 and max_score < 0.99: + num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) molecule_pool.random_dump(num_to_dump) file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") tol_level = 0 - if config["strategy"] == "rej-sample" and tol_level >= 5: + if config["strategy"] == "rej-sample": round_entries.extend(current_entries) round_entries = list(np.unique(round_entries))[::-1] top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) @@ -157,52 +160,4 @@ def optimize( molecule_pool.add(current_entries) file.write("Molecule pool\n") for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - - -# def optimize_reinvent( -# model, prior_model, -# 
tokenizer, oracle, -# config -# ): -# print("molecule pool size", config["molecule_pool_size"]) -# print("molecule generation arguments", config["mol_gen_kwargs"]) -# molecule_pool = MoleculePool(config["molecule_pool_size"]) - -# num_iter = 1 -# while True: -# generated_entries = [] -# while len(generated_entries) < config["num_gens_per_iter"]: -# prompts = create_optimization_prompts( -# config["num_gens_per_iter"], molecule_pool, -# max_similars_in_prompt=config["max_similars_in_prompt"], -# sim_range=config["sim_range"] -# ) -# output_texts = [] -# generation_batch_size = 200 -# for i in range(0, len(prompts), generation_batch_size): -# prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] -# data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) -# del data["token_type_ids"] -# output = model.generate( -# **data, -# **config["mol_gen_kwargs"] -# ) -# gc.collect() -# torch.cuda.empty_cache() -# output_texts.extend(tokenizer.batch_decode(output)) - -# with multiprocessing.Pool(processes=config["num_processes"]) as pol: -# for entry in pol.map(create_molecule_entry, output_texts) if entry]: -# entry.score = oracle(entry.smiles) -# generated_entries.append(entry) -# if oracle.finish or len(generated_entries) >= config["num_gens_per_iter"]: -# break - -# if oracle.finish: -# break - -# num_iter += 1 -# if oracle.finish: -# break -# molecule_pool.add(generated_entries) \ No newline at end of file + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") \ No newline at end of file From 7559e5b9e69223d0a2faacb6f1b4d31f944799a7 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 3 May 2024 19:47:55 +0400 Subject: [PATCH 07/45] add feature --- chemlactica/mol_opt/optimization.py | 16 ++++++++++------ chemlactica/mol_opt/utils.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 2e7e6e6..83a5324 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -74,7 +74,7 @@ def optimize( # print("molecule generation arguments", config["generation_config"]) molecule_pool = MoleculePool(config["molecule_pool_size"]) - if config["strategy"] == "rej-sample": + if "rej-sample" in config["strategy"]: round_entries = [] max_score = 0 @@ -108,8 +108,12 @@ def optimize( with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): if entry: - entry.score = oracle(entry.smiles) - entry.additional_properties["prompt"] = prompts[i] + if getattr(oracle, 'takes_entry', False): + oracle_score = oracle(entry) + else: + oracle_score = oracle(entry.smiles) + entry.score = oracle_score + entry.add_props["prompt"] = prompts[i] current_entries.append(entry) file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") if entry.score > max_score + 0.01: @@ -128,12 +132,12 @@ def optimize( break # print("tol_level", tol_level) - if config["strategy"] == "pool-dump" and tol_level >= 5 and max_score < 0.99: + if "pool-dump" in config["strategy"] and tol_level >= 5 and max_score < 0.99: num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) molecule_pool.random_dump(num_to_dump) file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") tol_level = 0 - if config["strategy"] == "rej-sample": + if "rej-sample" in config["strategy"]: round_entries.extend(current_entries) round_entries 
= list(np.unique(round_entries))[::-1] top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) @@ -145,7 +149,7 @@ def optimize( file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") train_dataset = Dataset.from_dict({ "sample": [ - f"{entry.additional_properties['prompt']}{entry.smiles}[END_SMILES]" + f"{entry.add_props['prompt']}{entry.smiles}[END_SMILES]" for entry in training_entries ] }) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 3416b43..e1915e8 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -56,7 +56,7 @@ def __init__(self, smiles, score=None, score_estimate=None, **kwargs): self.fingerprint = get_morgan_fingerprint(self.mol) self.score = score self.score_estimate = score_estimate - self.additional_properties = kwargs + self.add_props = kwargs def __eq__(self, other): return self.smiles == other.smiles From a0232818dea07943d76b1c7b10be24f2f4a9487e Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 7 May 2024 18:24:47 +0400 Subject: [PATCH 08/45] don't enter scoring code if oracle has exceeded budget --- chemlactica/mol_opt/optimization.py | 108 +++++++++++++++++----------- 1 file changed, 67 insertions(+), 41 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 83a5324..2934ad5 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -2,21 +2,28 @@ from datasets import Dataset import multiprocessing import gc -import math -import tqdm -import random import numpy as np -from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number +from chemlactica.mol_opt.utils import ( + MoleculeEntry, + MoleculePool, + generate_random_number, +) from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range, post_processor=None): +def create_optimization_prompts( + num_prompts, + molecule_pool, + max_similars_in_prompt: int, + sim_range, + post_processor=None, +): prompts = [] for i in range(num_prompts): similars_in_prompt = molecule_pool.random_subset(max_similars_in_prompt) prompt = "" for mol in similars_in_prompt: - prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" # noqa if post_processor: prompt = post_processor(prompt) prompt += "[START_SMILES]" @@ -31,7 +38,7 @@ def create_molecule_entry(output_text): if start_ind == -1 or end_ind == -1: return None - generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] + generated_smiles = output_text[start_ind + len(start_smiles_tag) : end_ind] # noqa for similar in output_text.split("[SIMILAR]")[1:-1]: similar_smiles = similar.split(" ")[0] if generated_smiles == similar_smiles: @@ -40,7 +47,7 @@ def create_molecule_entry(output_text): return MoleculeEntry( smiles=generated_smiles, ) - except: + except Exception: return None @@ -49,26 +56,22 @@ def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_ prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) del data["token_type_ids"] - outputs = model.generate( - **data, - **prop_pred_kwargs - ) + outputs = model.generate(**data, **prop_pred_kwargs) predicted_property_values = [] output_texts = 
tokenizer.batch_decode(outputs) for output_text in output_texts: start_ind = output_text.find(property_start_tag) end_ind = output_text.find(property_end_tag) if start_ind != -1 and end_ind != -1: - predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) + predicted_property_values.append( + output_text[start_ind + len(property_start_tag) : end_ind] # noqa + ) else: predicted_property_values.append(None) return predicted_property_values -def optimize( - model, tokenizer, - oracle, config - ): +def optimize(model, tokenizer, oracle, config): file = open(config["log_dir"], "w") print("config", config) # print("molecule generation arguments", config["generation_config"]) @@ -85,41 +88,52 @@ def optimize( current_entries = [] while len(current_entries) < config["num_gens_per_iter"]: prompts = create_optimization_prompts( - config["num_gens_per_iter"], molecule_pool, + config["num_gens_per_iter"], + molecule_pool, max_similars_in_prompt=config["max_similars_in_prompt"], - sim_range=config["sim_range"], post_processor=config.get("prompts_post_processor") + sim_range=config["sim_range"], + post_processor=config.get("prompts_post_processor"), ) output_texts = [] generation_batch_size = 64 for i in range(0, len(prompts), generation_batch_size): - prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] - data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) + prompt_batch = prompts[ + i : min(len(prompts), i + generation_batch_size) # noqa + ] # noqa + data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to( + model.device + ) del data["token_type_ids"] for key, value in data.items(): - data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] - output = model.generate( - **data, - **config["generation_config"] - ) + data[key] = value[ + :, + -2048 + config["generation_config"]["max_new_tokens"] :, # noqa + ] + output = model.generate(**data, **config["generation_config"]) gc.collect() torch.cuda.empty_cache() output_texts.extend(tokenizer.batch_decode(output)) with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry: - if getattr(oracle, 'takes_entry', False): + if entry and not oracle.finish: + if getattr(oracle, "takes_entry", False): oracle_score = oracle(entry) else: oracle_score = oracle(entry.smiles) entry.score = oracle_score entry.add_props["prompt"] = prompts[i] current_entries.append(entry) - file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") + file.write( + f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n" + ) if entry.score > max_score + 0.01: max_score = entry.score tol_level = 0 - if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: + if ( + oracle.finish + or len(current_entries) >= config["num_gens_per_iter"] + ): break if oracle.finish: @@ -133,29 +147,41 @@ def optimize( # print("tol_level", tol_level) if "pool-dump" in config["strategy"] and tol_level >= 5 and max_score < 0.99: - num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) + num_to_dump = int( + len(molecule_pool) * config["pool_dump_config"]["dump_perc"] + ) molecule_pool.random_dump(num_to_dump) - file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") + file.write( + f"Dump {num_to_dump} random elements from pool, \ + num pool mols {len(molecule_pool)}\n" + ) tol_level = 0 if "rej-sample" 
in config["strategy"]: round_entries.extend(current_entries) round_entries = list(np.unique(round_entries))[::-1] top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) - if len(round_entries[:top_k]) >= config["rej_sample_config"]["num_samples_per_round"]: + if ( + len(round_entries[:top_k]) + >= config["rej_sample_config"]["num_samples_per_round"] + ): training_entries = round_entries[:top_k] print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") for i, mol in enumerate(training_entries): file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - train_dataset = Dataset.from_dict({ - "sample": [ - f"{entry.add_props['prompt']}{entry.smiles}[END_SMILES]" - for entry in training_entries - ] - }) + train_dataset = Dataset.from_dict( + { + "sample": [ + f"{entry.add_props['prompt']}{entry.smiles}[END_SMILES]" + for entry in training_entries + ] + } + ) train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) + supervised_fine_tune( + model, tokenizer, train_dataset, config["rej_sample_config"] + ) round_entries = [] gc.collect() torch.cuda.empty_cache() @@ -164,4 +190,4 @@ def optimize( molecule_pool.add(current_entries) file.write("Molecule pool\n") for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") \ No newline at end of file + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") From 8ae3162c0972ac7b98240b427176028aced910c3 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 7 May 2024 18:27:20 +0400 Subject: [PATCH 09/45] add file generation utility --- chemlactica/mol_opt/utils.py | 60 +++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index e1915e8..cfbe4f0 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -1,14 +1,15 @@ from typing import List +import datetime +import os import random from pathlib import Path import numpy as np import torch from rdkit import Chem, DataStructs, RDLogger from rdkit.Chem import AllChem, MACCSkeys -from rdkit.Chem.QED import qed # Disable RDKit logs -RDLogger.DisableLog('rdApp.*') +RDLogger.DisableLog("rdApp.*") def set_seed(seed_value): @@ -19,9 +20,13 @@ def set_seed(seed_value): torch.manual_seed(seed_value) -def get_short_name_for_ckpt_path(chpt_path: str, hash_len: int=6): +def get_short_name_for_ckpt_path(chpt_path: str, hash_len: int = 6): get_short_name_for_ckpt_path = Path(chpt_path) - return get_short_name_for_ckpt_path.parent.name[:hash_len] + '-' + get_short_name_for_ckpt_path.name.split("-")[-1] + return ( + get_short_name_for_ckpt_path.parent.name[:hash_len] + + "-" + + get_short_name_for_ckpt_path.name.split("-")[-1] + ) def get_morgan_fingerprint(mol): @@ -32,10 +37,10 @@ def get_maccs_fingerprint(mol): return MACCSkeys.GenMACCSKeys(mol) -def tanimoto_dist_func(fing1, fing2, fingerprint: str="morgan"): +def tanimoto_dist_func(fing1, fing2, fingerprint: str = "morgan"): return DataStructs.TanimotoSimilarity( - fing1 if fingerprint == 'morgan' else fing1, - fing2 if fingerprint == 'morgan' else fing2, + fing1 if fingerprint == "morgan" else fing1, + fing2 if fingerprint == "morgan" else fing2, ) @@ -49,7 +54,6 @@ def canonicalize(smiles): class MoleculeEntry: - def __init__(self, smiles, score=None, 
score_estimate=None, **kwargs): self.smiles = canonicalize(smiles) self.mol = Chem.MolFromSmiles(smiles) @@ -65,18 +69,19 @@ def __lt__(self, other): if self.score == other.score: return self.smiles < other.smiles return self.score < other.score - + def __str__(self): - return f"smiles: {self.smiles}, " \ - f"score: {round(self.score, 4) if self.score != None else 'none'}, " \ - f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate != None else 'none'}" - + return ( + f"smiles: {self.smiles}, " + f"score: {round(self.score, 4) if self.score != None else 'none'}, " + f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate != None else 'none'}" # noqa + ) + def __repr__(self): return str(self) class MoleculePool: - def __init__(self, size): self.size = size self.molecule_entries: List[MoleculeEntry] = [] @@ -98,17 +103,36 @@ def add(self, entries: List[MoleculeEntry], diversity_score=1.0): for mol in self.molecule_entries: insert = True for m in new_molecule_entries: - if mol == m or tanimoto_dist_func(mol.fingerprint, m.fingerprint) > diversity_score: + if ( + mol == m + or tanimoto_dist_func(mol.fingerprint, m.fingerprint) + > diversity_score + ): insert = False break if insert: new_molecule_entries.append(mol) - self.molecule_entries = new_molecule_entries[:min(len(new_molecule_entries), self.size)] + self.molecule_entries = new_molecule_entries[ + : min(len(new_molecule_entries), self.size) + ] def random_subset(self, subset_size): rand_inds = np.random.permutation(min(len(self.molecule_entries), subset_size)) return [self.molecule_entries[i] for i in rand_inds] - + def __len__(self): - return len(self.molecule_entries) \ No newline at end of file + return len(self.molecule_entries) + + +def make_output_files_base(input_path, results_dir, run_name, config): + formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") + base = os.path.join(input_path, run_name, formatted_date_time) + os.makedirs(base, exist_ok=True) + v = 0 + strategy = "+".join(config["strategy"]) + while os.path.exists(os.path.join(base, f"{strategy}-{v}")): + v += 1 + output_dir = os.path.join(base, f"{strategy}-{v}") + os.makedirs(output_dir, exist_ok=True) + return output_dir From f2ee4681946551256466ca84e887a6f29c970502 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 10 May 2024 22:33:17 +0400 Subject: [PATCH 10/45] rej-sample-v2 --- chemlactica/mol_opt/optimization.py | 67 ++++++++++++++++++----------- chemlactica/mol_opt/tunning.py | 2 +- chemlactica/mol_opt/utils.py | 1 - 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 83a5324..8690870 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -6,35 +6,43 @@ import tqdm import random import numpy as np +from transformers import OPTForCausalLM from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts(num_prompts, molecule_pool, max_similars_in_prompt: int, sim_range, post_processor=None): +def create_optimization_prompts(num_prompts, molecule_pool, max_mols_in_prompt: int, post_processor=None): prompts = [] for i in range(num_prompts): - similars_in_prompt = molecule_pool.random_subset(max_similars_in_prompt) + similars_in_prompt = molecule_pool.random_subset(max_mols_in_prompt) prompt = "" + oracle_scores_of_mols_in_prompt = [] for mol in similars_in_prompt: - 
prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + prompt += f"[ORACLE_SCORE]{mol.score:.2f}[/ORACLE_SCORE][START_SMILES]{mol.smiles}[END_SMILES]" + # prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(0.8, 0.9)}[/SIMILAR]" + # prompt += f"[START_SMILES]{mol.smiles}[END_SMILES]" + oracle_scores_of_mols_in_prompt.append(mol.score) if post_processor: prompt = post_processor(prompt) - prompt += "[START_SMILES]" + q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 + required_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 + prompt += f"[ORACLE_SCORE]{required_oracle_score:.2f}[/ORACLE_SCORE][START_SMILES]" + # prompt += f"[START_SMILES]" prompts.append(prompt) return prompts def create_molecule_entry(output_text): start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]" - start_ind = output_text.find(start_smiles_tag) - end_ind = output_text.find(end_smiles_tag) + start_ind = output_text.rfind(start_smiles_tag) + end_ind = output_text.rfind(end_smiles_tag) if start_ind == -1 or end_ind == -1: return None - generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] - for similar in output_text.split("[SIMILAR]")[1:-1]: - similar_smiles = similar.split(" ")[0] - if generated_smiles == similar_smiles: + + for output in output_text.split(start_smiles_tag)[:-1]: + smiles_in_prompt = output.split(end_smiles_tag)[0] + if generated_smiles == smiles_in_prompt: return None try: return MoleculeEntry( @@ -75,26 +83,26 @@ def optimize( molecule_pool = MoleculePool(config["molecule_pool_size"]) if "rej-sample" in config["strategy"]: - round_entries = [] + round_entries = {} max_score = 0 tol_level = 0 - num_iter = 1 + num_iter = 0 while True: model.eval() current_entries = [] while len(current_entries) < config["num_gens_per_iter"]: prompts = create_optimization_prompts( config["num_gens_per_iter"], molecule_pool, - max_similars_in_prompt=config["max_similars_in_prompt"], - sim_range=config["sim_range"], post_processor=config.get("prompts_post_processor") + max_mols_in_prompt=config["max_mols_in_prompt"], + post_processor=config.get("prompts_post_processor") ) output_texts = [] - generation_batch_size = 64 - for i in range(0, len(prompts), generation_batch_size): - prompt_batch = prompts[i: min(len(prompts), i + generation_batch_size)] + for i in range(0, len(prompts), config["generation_batch_size"]): + prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) - del data["token_type_ids"] + if type(model) == OPTForCausalLM: + del data["token_type_ids"] for key, value in data.items(): data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] output = model.generate( @@ -122,6 +130,7 @@ def optimize( if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: break + # print(num_iter, len(current_entries)) if oracle.finish: break @@ -132,31 +141,39 @@ def optimize( break # print("tol_level", tol_level) - if "pool-dump" in config["strategy"] and tol_level >= 5 and max_score < 0.99: + if "pool-dump" in config["strategy"] and tol_level >= 5: num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) molecule_pool.random_dump(num_to_dump) file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") tol_level = 0 if "rej-sample" in config["strategy"]: - 
round_entries.extend(current_entries) - round_entries = list(np.unique(round_entries))[::-1] + # round_entries.extend(current_entries) + # round_entries = list(np.unique(round_entries))[::-1] + for entry in current_entries: + round_entries[entry.smiles] = entry top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) - if len(round_entries[:top_k]) >= config["rej_sample_config"]["num_samples_per_round"]: - training_entries = round_entries[:top_k] + # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: + if num_iter % 10 == 0: + training_entries = np.unique(molecule_pool.molecule_entries)[-top_k:] print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") for i, mol in enumerate(training_entries): file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") + + def create_training_sample(entry): + sample = entry.add_props["prompt"] + return sample + f"[START_SMILES]{entry.smiles}[END_SMILES]" + train_dataset = Dataset.from_dict({ "sample": [ - f"{entry.add_props['prompt']}{entry.smiles}[END_SMILES]" + create_training_sample(entry) for entry in training_entries ] }) train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) - round_entries = [] + round_entries = {} gc.collect() torch.cuda.empty_cache() diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 41de28c..7ff56e5 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -45,7 +45,7 @@ def supervised_fine_tune( packing=config["packing"], tokenizer=tokenizer, max_seq_length=config["max_seq_length"], - data_collator=collator, + # data_collator=collator, optimizers=[optimizer, lr_scheduler] ) trainer.train() diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index e1915e8..7f191fd 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -5,7 +5,6 @@ import torch from rdkit import Chem, DataStructs, RDLogger from rdkit.Chem import AllChem, MACCSkeys -from rdkit.Chem.QED import qed # Disable RDKit logs RDLogger.DisableLog('rdApp.*') From 44cfb4b3605f951354b65b32266cb71eaa935332 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 13 May 2024 17:33:37 +0400 Subject: [PATCH 11/45] rej-sample-v2.1 --- chemlactica/mol_opt/optimization.py | 105 +++++++++++++++------------- chemlactica/mol_opt/tunning.py | 38 ++++++++-- 2 files changed, 90 insertions(+), 53 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 8690870..2a3235b 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -11,23 +11,28 @@ from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts(num_prompts, molecule_pool, max_mols_in_prompt: int, post_processor=None): +def create_optimization_prompts(num_prompts, molecule_pool, max_mols_in_prompt: int, strategy: str, eos_token: str, post_processor=None): prompts = [] for i in range(num_prompts): similars_in_prompt = molecule_pool.random_subset(max_mols_in_prompt) - prompt = "" + prompt = eos_token oracle_scores_of_mols_in_prompt = [] for mol in similars_in_prompt: - prompt += f"[ORACLE_SCORE]{mol.score:.2f}[/ORACLE_SCORE][START_SMILES]{mol.smiles}[END_SMILES]" - # prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(0.8, 0.9)}[/SIMILAR]" + if "default" in strategy: + prompt += f"[SIMILAR]{mol.smiles} 
{generate_random_number(0.8, 0.9):.2f}[/SIMILAR]" + elif "rej-sample" in strategy: + prompt += f"[ORACLE_SCORE]{mol.score:.2f}[/ORACLE_SCORE][START_SMILES]{mol.smiles}[END_SMILES]" # prompt += f"[START_SMILES]{mol.smiles}[END_SMILES]" oracle_scores_of_mols_in_prompt.append(mol.score) if post_processor: prompt = post_processor(prompt) q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 required_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 - prompt += f"[ORACLE_SCORE]{required_oracle_score:.2f}[/ORACLE_SCORE][START_SMILES]" - # prompt += f"[START_SMILES]" + if "default" in strategy: + prompt += f"[START_SMILES]" + elif "rej-sample" in strategy: + prompt += f"[ORACLE_SCORE]{required_oracle_score:.2f}[/ORACLE_SCORE][START_SMILES]" + prompts.append(prompt) return prompts @@ -52,25 +57,25 @@ def create_molecule_entry(output_text): return None -def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): - property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" - prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] - data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - del data["token_type_ids"] - outputs = model.generate( - **data, - **prop_pred_kwargs - ) - predicted_property_values = [] - output_texts = tokenizer.batch_decode(outputs) - for output_text in output_texts: - start_ind = output_text.find(property_start_tag) - end_ind = output_text.find(property_end_tag) - if start_ind != -1 and end_ind != -1: - predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) - else: - predicted_property_values.append(None) - return predicted_property_values +# def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): +# property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" +# prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] +# data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) +# del data["token_type_ids"] +# outputs = model.generate( +# **data, +# **prop_pred_kwargs +# ) +# predicted_property_values = [] +# output_texts = tokenizer.batch_decode(outputs) +# for output_text in output_texts: +# start_ind = output_text.find(property_start_tag) +# end_ind = output_text.find(property_end_tag) +# if start_ind != -1 and end_ind != -1: +# predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) +# else: +# predicted_property_values.append(None) +# return predicted_property_values def optimize( @@ -83,7 +88,7 @@ def optimize( molecule_pool = MoleculePool(config["molecule_pool_size"]) if "rej-sample" in config["strategy"]: - round_entries = {} + all_entries = {} max_score = 0 tol_level = 0 @@ -93,8 +98,11 @@ def optimize( current_entries = [] while len(current_entries) < config["num_gens_per_iter"]: prompts = create_optimization_prompts( - config["num_gens_per_iter"], molecule_pool, + config["num_gens_per_iter"], + molecule_pool, max_mols_in_prompt=config["max_mols_in_prompt"], + eos_token=config["eos_token"], + strategy=config["strategy"], post_processor=config.get("prompts_post_processor") ) output_texts = [] @@ -134,27 +142,27 @@ def optimize( if oracle.finish: break - current_entries = list(np.unique(current_entries))[::-1] - tol_level += 1 - num_iter += 1 if oracle.finish: break + current_entries = list(np.unique(current_entries))[::-1] + initial_num_iter = num_iter + 
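+        # NOTE: the next line assumes oracle.mol_buffer holds one record per
+        # scored molecule, so the iteration index can be re-derived from the
+        # consumed oracle budget instead of a locally incremented counter.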
num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] + print("num_iter: ", num_iter) + + # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) + molecule_pool.add(current_entries) + file.write("Molecule pool\n") + for i, mol in enumerate(molecule_pool.molecule_entries): + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - # print("tol_level", tol_level) - if "pool-dump" in config["strategy"] and tol_level >= 5: - num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) - molecule_pool.random_dump(num_to_dump) - file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") - tol_level = 0 if "rej-sample" in config["strategy"]: # round_entries.extend(current_entries) # round_entries = list(np.unique(round_entries))[::-1] - for entry in current_entries: - round_entries[entry.smiles] = entry - top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) + # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if num_iter % 10 == 0: - training_entries = np.unique(molecule_pool.molecule_entries)[-top_k:] + # if num_iter > initial_num_iter and num_iter % 3 == 0: + if tol_level >= 2: + training_entries = molecule_pool.molecule_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") for i, mol in enumerate(training_entries): @@ -162,7 +170,7 @@ def optimize( def create_training_sample(entry): sample = entry.add_props["prompt"] - return sample + f"[START_SMILES]{entry.smiles}[END_SMILES]" + return sample + f"[START_SMILES]{entry.smiles}[END_SMILES]" train_dataset = Dataset.from_dict({ "sample": [ @@ -173,12 +181,11 @@ def create_training_sample(entry): train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) - round_entries = {} gc.collect() torch.cuda.empty_cache() - - # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - molecule_pool.add(current_entries) - file.write("Molecule pool\n") - for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") \ No newline at end of file + tol_level = 0 + if "pool-dump" in config["strategy"] and tol_level >= 10: + num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) + molecule_pool.random_dump(num_to_dump) + file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") + tol_level = 0 \ No newline at end of file diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 7ff56e5..a4c524c 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -1,7 +1,34 @@ +from transformers.trainer_callback import TrainerControl, TrainerState from trl import SFTTrainer, DataCollatorForCompletionOnlyLM -from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup +from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, TrainerCallback from torch.optim.lr_scheduler import ConstantLR import torch +import math + + +class CustomSFTTrainer(SFTTrainer): + + def __init__(self, *args, patience, toll, **kwargs): + super().__init__(*args, **kwargs) + self.patience = patience + self.initial_pat = patience + self.toll = toll + self.best_loss = math.inf 
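+        # best_loss tracks the lowest training loss seen so far; `patience`
+        # counts consecutive non-improving logging steps and `toll` is the
+        # minimum improvement that resets it (see the log() override below).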
+ + def log(self, logs) -> None: + if logs.get("loss"): + curr_loss = logs["loss"] + if curr_loss > self.best_loss - self.toll: + self.patience -= 1 + print(f"loss did not improve, patience {self.patience}") + else: + print("loss improved") + self.best_loss = curr_loss + self.patience = self.initial_pat + if self.patience == 0: + print("The loss does not improve, stop training.") + self.control.should_training_stop = True + return super().log(logs) def supervised_fine_tune( @@ -19,7 +46,8 @@ def supervised_fine_tune( dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], gradient_accumulation_steps=config["gradient_accumulation_steps"], - logging_steps=1 + logging_steps=1, + metric_for_best_model="loss", ) optimizer = torch.optim.AdamW( model.parameters(), @@ -37,7 +65,7 @@ def supervised_fine_tune( collator = DataCollatorForCompletionOnlyLM( config["response_template"], tokenizer=tokenizer ) - trainer = SFTTrainer( + trainer = CustomSFTTrainer( model=model, train_dataset=train_dataset, formatting_func=config["formatting_func"], @@ -46,6 +74,8 @@ def supervised_fine_tune( tokenizer=tokenizer, max_seq_length=config["max_seq_length"], # data_collator=collator, - optimizers=[optimizer, lr_scheduler] + optimizers=[optimizer, lr_scheduler], + patience=2, + toll=0.0001 ) trainer.train() From 898f5c10eb79f3cadf668fdedeca24ecead65487 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 14 May 2024 11:39:31 +0400 Subject: [PATCH 12/45] refine --- chemlactica/mol_opt/optimization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 2a3235b..eb38083 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -160,8 +160,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - # if num_iter > initial_num_iter and num_iter % 3 == 0: - if tol_level >= 2: + if tol_level >= 3 and num_iter > initial_num_iter: training_entries = molecule_pool.molecule_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") From a68b4f252123ecc8d4fa4fcc3278eace738c0478 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 10 May 2024 22:33:17 +0400 Subject: [PATCH 13/45] rej-sample-v2 --- chemlactica/mol_opt/optimization.py | 221 ++++++++++++++-------------- chemlactica/mol_opt/tunning.py | 2 +- 2 files changed, 110 insertions(+), 113 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 2934ad5..eb38083 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -2,87 +2,97 @@ from datasets import Dataset import multiprocessing import gc +import math +import tqdm +import random import numpy as np -from chemlactica.mol_opt.utils import ( - MoleculeEntry, - MoleculePool, - generate_random_number, -) +from transformers import OPTForCausalLM +from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts( - num_prompts, - molecule_pool, - max_similars_in_prompt: int, - sim_range, - post_processor=None, -): +def create_optimization_prompts(num_prompts, molecule_pool, max_mols_in_prompt: int, strategy: str, eos_token: str, post_processor=None): prompts = [] for i in 
range(num_prompts): - similars_in_prompt = molecule_pool.random_subset(max_similars_in_prompt) - prompt = "" + similars_in_prompt = molecule_pool.random_subset(max_mols_in_prompt) + prompt = eos_token + oracle_scores_of_mols_in_prompt = [] for mol in similars_in_prompt: - prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" # noqa + if "default" in strategy: + prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(0.8, 0.9):.2f}[/SIMILAR]" + elif "rej-sample" in strategy: + prompt += f"[ORACLE_SCORE]{mol.score:.2f}[/ORACLE_SCORE][START_SMILES]{mol.smiles}[END_SMILES]" + # prompt += f"[START_SMILES]{mol.smiles}[END_SMILES]" + oracle_scores_of_mols_in_prompt.append(mol.score) if post_processor: prompt = post_processor(prompt) - prompt += "[START_SMILES]" + q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 + required_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 + if "default" in strategy: + prompt += f"[START_SMILES]" + elif "rej-sample" in strategy: + prompt += f"[ORACLE_SCORE]{required_oracle_score:.2f}[/ORACLE_SCORE][START_SMILES]" + prompts.append(prompt) return prompts def create_molecule_entry(output_text): start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]" - start_ind = output_text.find(start_smiles_tag) - end_ind = output_text.find(end_smiles_tag) + start_ind = output_text.rfind(start_smiles_tag) + end_ind = output_text.rfind(end_smiles_tag) if start_ind == -1 or end_ind == -1: return None + generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] - generated_smiles = output_text[start_ind + len(start_smiles_tag) : end_ind] # noqa - for similar in output_text.split("[SIMILAR]")[1:-1]: - similar_smiles = similar.split(" ")[0] - if generated_smiles == similar_smiles: + for output in output_text.split(start_smiles_tag)[:-1]: + smiles_in_prompt = output.split(end_smiles_tag)[0] + if generated_smiles == smiles_in_prompt: return None try: return MoleculeEntry( smiles=generated_smiles, ) - except Exception: + except: return None -def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): - property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" - prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] - data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - del data["token_type_ids"] - outputs = model.generate(**data, **prop_pred_kwargs) - predicted_property_values = [] - output_texts = tokenizer.batch_decode(outputs) - for output_text in output_texts: - start_ind = output_text.find(property_start_tag) - end_ind = output_text.find(property_end_tag) - if start_ind != -1 and end_ind != -1: - predicted_property_values.append( - output_text[start_ind + len(property_start_tag) : end_ind] # noqa - ) - else: - predicted_property_values.append(None) - return predicted_property_values - - -def optimize(model, tokenizer, oracle, config): +# def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): +# property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" +# prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] +# data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) +# del data["token_type_ids"] +# outputs = model.generate( +# **data, +# **prop_pred_kwargs +# ) +# predicted_property_values = [] +# output_texts = tokenizer.batch_decode(outputs) +# for 
output_text in output_texts: +# start_ind = output_text.find(property_start_tag) +# end_ind = output_text.find(property_end_tag) +# if start_ind != -1 and end_ind != -1: +# predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) +# else: +# predicted_property_values.append(None) +# return predicted_property_values + + +def optimize( + model, tokenizer, + oracle, config + ): file = open(config["log_dir"], "w") print("config", config) # print("molecule generation arguments", config["generation_config"]) molecule_pool = MoleculePool(config["molecule_pool_size"]) if "rej-sample" in config["strategy"]: - round_entries = [] + all_entries = {} max_score = 0 tol_level = 0 - num_iter = 1 + num_iter = 0 while True: model.eval() current_entries = [] @@ -90,104 +100,91 @@ def optimize(model, tokenizer, oracle, config): prompts = create_optimization_prompts( config["num_gens_per_iter"], molecule_pool, - max_similars_in_prompt=config["max_similars_in_prompt"], - sim_range=config["sim_range"], - post_processor=config.get("prompts_post_processor"), + max_mols_in_prompt=config["max_mols_in_prompt"], + eos_token=config["eos_token"], + strategy=config["strategy"], + post_processor=config.get("prompts_post_processor") ) output_texts = [] - generation_batch_size = 64 - for i in range(0, len(prompts), generation_batch_size): - prompt_batch = prompts[ - i : min(len(prompts), i + generation_batch_size) # noqa - ] # noqa - data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to( - model.device - ) - del data["token_type_ids"] + for i in range(0, len(prompts), config["generation_batch_size"]): + prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] + data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) + if type(model) == OPTForCausalLM: + del data["token_type_ids"] for key, value in data.items(): - data[key] = value[ - :, - -2048 + config["generation_config"]["max_new_tokens"] :, # noqa - ] - output = model.generate(**data, **config["generation_config"]) + data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] + output = model.generate( + **data, + **config["generation_config"] + ) gc.collect() torch.cuda.empty_cache() output_texts.extend(tokenizer.batch_decode(output)) with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry and not oracle.finish: - if getattr(oracle, "takes_entry", False): + if entry: + if getattr(oracle, 'takes_entry', False): oracle_score = oracle(entry) else: oracle_score = oracle(entry.smiles) entry.score = oracle_score entry.add_props["prompt"] = prompts[i] current_entries.append(entry) - file.write( - f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n" - ) + file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") if entry.score > max_score + 0.01: max_score = entry.score tol_level = 0 - if ( - oracle.finish - or len(current_entries) >= config["num_gens_per_iter"] - ): + if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: break + # print(num_iter, len(current_entries)) if oracle.finish: break - current_entries = list(np.unique(current_entries))[::-1] - tol_level += 1 - num_iter += 1 if oracle.finish: break + current_entries = list(np.unique(current_entries))[::-1] + initial_num_iter = num_iter + num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] + print("num_iter: ", num_iter) + + # diversity_score = 1 / (1 
+ math.log(1 + repeated_max_score) / math.log(10)) + molecule_pool.add(current_entries) + file.write("Molecule pool\n") + for i, mol in enumerate(molecule_pool.molecule_entries): + file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - # print("tol_level", tol_level) - if "pool-dump" in config["strategy"] and tol_level >= 5 and max_score < 0.99: - num_to_dump = int( - len(molecule_pool) * config["pool_dump_config"]["dump_perc"] - ) - molecule_pool.random_dump(num_to_dump) - file.write( - f"Dump {num_to_dump} random elements from pool, \ - num pool mols {len(molecule_pool)}\n" - ) - tol_level = 0 if "rej-sample" in config["strategy"]: - round_entries.extend(current_entries) - round_entries = list(np.unique(round_entries))[::-1] - top_k = int(len(round_entries) * config["rej_sample_config"]["rej_perc"]) - if ( - len(round_entries[:top_k]) - >= config["rej_sample_config"]["num_samples_per_round"] - ): - training_entries = round_entries[:top_k] + # round_entries.extend(current_entries) + # round_entries = list(np.unique(round_entries))[::-1] + # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) + # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: + if tol_level >= 3 and num_iter > initial_num_iter: + training_entries = molecule_pool.molecule_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") for i, mol in enumerate(training_entries): file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - train_dataset = Dataset.from_dict( - { - "sample": [ - f"{entry.add_props['prompt']}{entry.smiles}[END_SMILES]" - for entry in training_entries - ] - } - ) + + def create_training_sample(entry): + sample = entry.add_props["prompt"] + return sample + f"[START_SMILES]{entry.smiles}[END_SMILES]" + + train_dataset = Dataset.from_dict({ + "sample": [ + create_training_sample(entry) + for entry in training_entries + ] + }) train_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune( - model, tokenizer, train_dataset, config["rej_sample_config"] - ) - round_entries = [] + supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) gc.collect() torch.cuda.empty_cache() - - # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - molecule_pool.add(current_entries) - file.write("Molecule pool\n") - for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") + tol_level = 0 + if "pool-dump" in config["strategy"] and tol_level >= 10: + num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) + molecule_pool.random_dump(num_to_dump) + file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") + tol_level = 0 \ No newline at end of file diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 41de28c..7ff56e5 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -45,7 +45,7 @@ def supervised_fine_tune( packing=config["packing"], tokenizer=tokenizer, max_seq_length=config["max_seq_length"], - data_collator=collator, + # data_collator=collator, optimizers=[optimizer, lr_scheduler] ) trainer.train() From f5e1ca23ee29d156ecd98d53937e57c03be1b27b Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 13 May 2024 17:33:37 +0400 Subject: [PATCH 14/45] rej-sample-v2.1 --- chemlactica/mol_opt/optimization.py | 5 ++++ 
chemlactica/mol_opt/tunning.py | 38 ++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index eb38083..2fa2d32 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -160,7 +160,12 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: +<<<<<<< HEAD if tol_level >= 3 and num_iter > initial_num_iter: +======= + # if num_iter > initial_num_iter and num_iter % 3 == 0: + if tol_level >= 2: +>>>>>>> 44cfb4b (rej-sample-v2.1) training_entries = molecule_pool.molecule_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 7ff56e5..a4c524c 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -1,7 +1,34 @@ +from transformers.trainer_callback import TrainerControl, TrainerState from trl import SFTTrainer, DataCollatorForCompletionOnlyLM -from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup +from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, TrainerCallback from torch.optim.lr_scheduler import ConstantLR import torch +import math + + +class CustomSFTTrainer(SFTTrainer): + + def __init__(self, *args, patience, toll, **kwargs): + super().__init__(*args, **kwargs) + self.patience = patience + self.initial_pat = patience + self.toll = toll + self.best_loss = math.inf + + def log(self, logs) -> None: + if logs.get("loss"): + curr_loss = logs["loss"] + if curr_loss > self.best_loss - self.toll: + self.patience -= 1 + print(f"loss did not improve, patience {self.patience}") + else: + print("loss improved") + self.best_loss = curr_loss + self.patience = self.initial_pat + if self.patience == 0: + print("The loss does not improve, stop training.") + self.control.should_training_stop = True + return super().log(logs) def supervised_fine_tune( @@ -19,7 +46,8 @@ def supervised_fine_tune( dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], gradient_accumulation_steps=config["gradient_accumulation_steps"], - logging_steps=1 + logging_steps=1, + metric_for_best_model="loss", ) optimizer = torch.optim.AdamW( model.parameters(), @@ -37,7 +65,7 @@ def supervised_fine_tune( collator = DataCollatorForCompletionOnlyLM( config["response_template"], tokenizer=tokenizer ) - trainer = SFTTrainer( + trainer = CustomSFTTrainer( model=model, train_dataset=train_dataset, formatting_func=config["formatting_func"], @@ -46,6 +74,8 @@ def supervised_fine_tune( tokenizer=tokenizer, max_seq_length=config["max_seq_length"], # data_collator=collator, - optimizers=[optimizer, lr_scheduler] + optimizers=[optimizer, lr_scheduler], + patience=2, + toll=0.0001 ) trainer.train() From 52b3a641811724cde945779bf49229701143b51f Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 14 May 2024 12:05:14 +0400 Subject: [PATCH 15/45] merge --- chemlactica/mol_opt/optimization.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 2fa2d32..8d02b94 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -57,27 +57,6 @@ def 
create_molecule_entry(output_text): return None -# def query_molecule_properties(model, tokenizer, smiles, property_tag, prop_pred_kwargs): -# property_start_tag, property_end_tag = f"[{property_tag}]", f"[/{property_tag}]" -# prompts = [f"[START_SMILES]{smiles}[END_SMILES][{property_tag}]"] -# data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) -# del data["token_type_ids"] -# outputs = model.generate( -# **data, -# **prop_pred_kwargs -# ) -# predicted_property_values = [] -# output_texts = tokenizer.batch_decode(outputs) -# for output_text in output_texts: -# start_ind = output_text.find(property_start_tag) -# end_ind = output_text.find(property_end_tag) -# if start_ind != -1 and end_ind != -1: -# predicted_property_values.append(output_text[start_ind+len(property_start_tag):end_ind]) -# else: -# predicted_property_values.append(None) -# return predicted_property_values - - def optimize( model, tokenizer, oracle, config @@ -160,12 +139,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: -<<<<<<< HEAD if tol_level >= 3 and num_iter > initial_num_iter: -======= - # if num_iter > initial_num_iter and num_iter % 3 == 0: - if tol_level >= 2: ->>>>>>> 44cfb4b (rej-sample-v2.1) training_entries = molecule_pool.molecule_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") From cff4db79cd2949fd0f7048654da8a29275f3921c Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 15 May 2024 20:30:49 +0400 Subject: [PATCH 16/45] rej-sample-v2 refac --- chemlactica/mol_opt/optimization.py | 154 ++++++++++++++-------------- chemlactica/mol_opt/utils.py | 115 ++++++++++++++++----- 2 files changed, 166 insertions(+), 103 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 8d02b94..50fbad7 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -1,3 +1,4 @@ +from typing import List import torch from datasets import Dataset import multiprocessing @@ -7,34 +8,35 @@ import random import numpy as np from transformers import OPTForCausalLM -from chemlactica.mol_opt.utils import MoleculeEntry, MoleculePool, generate_random_number +from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func from chemlactica.mol_opt.tunning import supervised_fine_tune -def create_optimization_prompts(num_prompts, molecule_pool, max_mols_in_prompt: int, strategy: str, eos_token: str, post_processor=None): - prompts = [] - for i in range(num_prompts): - similars_in_prompt = molecule_pool.random_subset(max_mols_in_prompt) - prompt = eos_token - oracle_scores_of_mols_in_prompt = [] - for mol in similars_in_prompt: - if "default" in strategy: - prompt += f"[SIMILAR]{mol.smiles} {generate_random_number(0.8, 0.9):.2f}[/SIMILAR]" - elif "rej-sample" in strategy: - prompt += f"[ORACLE_SCORE]{mol.score:.2f}[/ORACLE_SCORE][START_SMILES]{mol.smiles}[END_SMILES]" - # prompt += f"[START_SMILES]{mol.smiles}[END_SMILES]" - oracle_scores_of_mols_in_prompt.append(mol.score) - if post_processor: - prompt = post_processor(prompt) - q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 - required_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 - if "default" in strategy: - prompt += f"[START_SMILES]" - 
elif "rej-sample" in strategy: - prompt += f"[ORACLE_SCORE]{required_oracle_score:.2f}[/ORACLE_SCORE][START_SMILES]" - - prompts.append(prompt) - return prompts +def create_similar_mol_entries(pool, mol_entry, num_similars): + similar_entries = [e.last_entry for e in pool.random_subset(num_similars + 1)] + count = 0 + valid_similar_entries = [] + for similar_entry in similar_entries: + if count >= num_similars: + break + if similar_entry == mol_entry: + continue + valid_similar_entries.append(similar_entry) + count += 1 + return valid_similar_entries + + +def create_optimization_entries(num_entries, pool, config): + optim_entries = [] + for i in range(num_entries): + mol_entries = [e.last_entry for e in pool.random_subset(config["num_mols"])] + entries = [] + for mol_entry in mol_entries: + similar_mol_entries = create_similar_mol_entries(pool, mol_entry, num_similars=config["num_similars"]) + mol_entry.add_props["similar_mol_entries"] = similar_mol_entries + entries.append(mol_entry) + optim_entries.append(OptimEntry(None, entries)) + return optim_entries def create_molecule_entry(output_text): @@ -44,6 +46,8 @@ def create_molecule_entry(output_text): if start_ind == -1 or end_ind == -1: return None generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] + if len(generated_smiles) == 0: + return None for output in output_text.split(start_smiles_tag)[:-1]: smiles_in_prompt = output.split(end_smiles_tag)[0] @@ -64,26 +68,28 @@ def optimize( file = open(config["log_dir"], "w") print("config", config) # print("molecule generation arguments", config["generation_config"]) - molecule_pool = MoleculePool(config["molecule_pool_size"]) - - if "rej-sample" in config["strategy"]: - all_entries = {} + pool = Pool(config["pool_size"]) max_score = 0 tol_level = 0 num_iter = 0 while True: model.eval() - current_entries = [] - while len(current_entries) < config["num_gens_per_iter"]: - prompts = create_optimization_prompts( - config["num_gens_per_iter"], - molecule_pool, - max_mols_in_prompt=config["max_mols_in_prompt"], - eos_token=config["eos_token"], - strategy=config["strategy"], - post_processor=config.get("prompts_post_processor") + iter_optim_entries: List[OptimEntry] = [] + while len(iter_optim_entries) < config["num_gens_per_iter"]: + optim_entries = create_optimization_entries( + config["num_gens_per_iter"], pool, + config=config ) + for i in range(len(optim_entries)): + last_entry = MoleculeEntry(smiles="") + last_entry.add_props["similar_mol_entries"] = create_similar_mol_entries( + pool, last_entry, config["num_similars"] + ) + optim_entries[i].last_entry = last_entry + + prompts = [optim_entry.to_prompt(is_generation=True, config=config) for optim_entry in optim_entries] + output_texts = [] for i in range(0, len(prompts), config["generation_batch_size"]): prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] @@ -100,60 +106,63 @@ def optimize( torch.cuda.empty_cache() output_texts.extend(tokenizer.batch_decode(output)) + current_mol_entries = [] + current_optim_entries = [] with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): if entry: - if getattr(oracle, 'takes_entry', False): - oracle_score = oracle(entry) - else: - oracle_score = oracle(entry.smiles) - entry.score = oracle_score - entry.add_props["prompt"] = prompts[i] - current_entries.append(entry) - file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") - if entry.score > max_score + 
0.01: - max_score = entry.score - tol_level = 0 - if oracle.finish or len(current_entries) >= config["num_gens_per_iter"]: - break - - # print(num_iter, len(current_entries)) + current_mol_entries.append(entry) + current_optim_entries.append(optim_entries[i]) + + if getattr(oracle, "takes_entry", False): + oracle_scores = oracle(current_mol_entries) + else: + oracle_scores = oracle([e.smiles for e in current_mol_entries]) + for i, oracle_score in enumerate(oracle_scores): + entry = current_mol_entries[i] + entry.score = oracle_score + entry.add_props["similar_mol_entries"] = current_optim_entries[i].last_entry.add_props["similar_mol_entries"] + current_optim_entries[i].last_entry = entry + iter_optim_entries.append(current_optim_entries[i]) + file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") + if entry.score > max_score + 0.01: + max_score = entry.score + tol_level = 0 + if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: + break + + if oracle.finish: break if oracle.finish: break - current_entries = list(np.unique(current_entries))[::-1] initial_num_iter = num_iter num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] print("num_iter: ", num_iter) # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - molecule_pool.add(current_entries) - file.write("Molecule pool\n") - for i, mol in enumerate(molecule_pool.molecule_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") + pool.add(iter_optim_entries) + file.write("Pool\n") + for i, optim_entry in enumerate(pool.optim_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - if "rej-sample" in config["strategy"]: + if "rej-sample-v2" in config["strategy"]: # round_entries.extend(current_entries) # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if tol_level >= 3 and num_iter > initial_num_iter: - training_entries = molecule_pool.molecule_entries + if num_iter % 3 == 0 and num_iter > initial_num_iter: + training_entries = pool.optim_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") - for i, mol in enumerate(training_entries): - file.write(f"\t{i} smiles: {mol.smiles}, score: {mol.score:.4f}\n") - - def create_training_sample(entry): - sample = entry.add_props["prompt"] - return sample + f"[START_SMILES]{entry.smiles}[END_SMILES]" + for i, optim_entry in enumerate(training_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") train_dataset = Dataset.from_dict({ "sample": [ - create_training_sample(entry) - for entry in training_entries + optim_entry.to_prompt(is_generation=False, config=config) + for optim_entry in training_entries ] }) train_dataset.shuffle(seed=42) @@ -161,9 +170,4 @@ def create_training_sample(entry): supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) gc.collect() torch.cuda.empty_cache() - tol_level = 0 - if "pool-dump" in config["strategy"] and tol_level >= 10: - num_to_dump = int(len(molecule_pool) * config["pool_dump_config"]["dump_perc"]) - molecule_pool.random_dump(num_to_dump) - file.write(f"Dump {num_to_dump} random elements from pool, num pool mols {len(molecule_pool)}\n") - tol_level = 0 \ No newline at end of file + tol_level = 0 \ No newline at end of 
file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index cfbe4f0..98bdfb0 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -54,12 +54,13 @@ def canonicalize(smiles): class MoleculeEntry: - def __init__(self, smiles, score=None, score_estimate=None, **kwargs): - self.smiles = canonicalize(smiles) - self.mol = Chem.MolFromSmiles(smiles) - self.fingerprint = get_morgan_fingerprint(self.mol) - self.score = score - self.score_estimate = score_estimate + def __init__(self, smiles, score=None, **kwargs): + self.smiles = smiles + if smiles: + self.smiles = canonicalize(smiles) + self.mol = Chem.MolFromSmiles(smiles) + self.fingerprint = get_morgan_fingerprint(self.mol) + self.score = score self.add_props = kwargs def __eq__(self, other): @@ -73,56 +74,55 @@ def __lt__(self, other): def __str__(self): return ( f"smiles: {self.smiles}, " - f"score: {round(self.score, 4) if self.score != None else 'none'}, " - f"score_estimate: {round(self.score_estimate, 4) if self.score_estimate != None else 'none'}" # noqa + f"score: {round(self.score, 4) if self.score != None else 'none'}" ) def __repr__(self): return str(self) -class MoleculePool: +class Pool: def __init__(self, size): self.size = size - self.molecule_entries: List[MoleculeEntry] = [] + self.optim_entries: List[OptimEntry] = [] - def random_dump(self, num): - for _ in range(num): - rand_ind = random.randint(0, num - 1) - self.molecule_entries.pop(rand_ind) - print(f"Dump {num} random elements from pool, num pool mols {len(self)}") + # def random_dump(self, num): + # for _ in range(num): + # rand_ind = random.randint(0, num - 1) + # self.molecule_entries.pop(rand_ind) + # print(f"Dump {num} random elements from pool, num pool mols {len(self)}") def add(self, entries: List[MoleculeEntry], diversity_score=1.0): assert type(entries) == list - self.molecule_entries.extend(entries) - self.molecule_entries.sort(reverse=True) + self.optim_entries.extend(entries) + self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True) # print(f"Updating with div_score {diversity_score:.4f}") # remove doublicates - new_molecule_entries = [] - for mol in self.molecule_entries: + new_optim_entries = [] + for entry in self.optim_entries: insert = True - for m in new_molecule_entries: + for e in new_optim_entries: if ( - mol == m - or tanimoto_dist_func(mol.fingerprint, m.fingerprint) + entry == e + or tanimoto_dist_func(entry.last_entry.fingerprint, e.last_entry.fingerprint) > diversity_score ): insert = False break if insert: - new_molecule_entries.append(mol) + new_optim_entries.append(entry) - self.molecule_entries = new_molecule_entries[ - : min(len(new_molecule_entries), self.size) + self.optim_entries = new_optim_entries[ + : min(len(new_optim_entries), self.size) ] def random_subset(self, subset_size): - rand_inds = np.random.permutation(min(len(self.molecule_entries), subset_size)) - return [self.molecule_entries[i] for i in rand_inds] + rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) + return [self.optim_entries[i] for i in rand_inds] def __len__(self): - return len(self.molecule_entries) + return len(self.optim_entries) def make_output_files_base(input_path, results_dir, run_name, config): @@ -136,3 +136,62 @@ def make_output_files_base(input_path, results_dir, run_name, config): output_dir = os.path.join(base, f"{strategy}-{v}") os.makedirs(output_dir, exist_ok=True) return output_dir + + +def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None): + 
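+    # Emits one [SIMILAR]<smiles> <similarity>[/SIMILAR] block per stored
+    # neighbour: when sim_range is given, the similarity value is sampled
+    # uniformly from that range (generation-time prompts); otherwise the true
+    # Tanimoto similarity to mol_entry is used (training-time prompts).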
prompt = "" + for sim_mol_entry in mol_entry.add_props["similar_mol_entries"]: + if sim_range: + prompt += f"[SIMILAR]{sim_mol_entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + else: + prompt += f"[SIMILAR]{sim_mol_entry.smiles} {tanimoto_dist_func(sim_mol_entry.fingerprint, mol_entry.fingerprint):.2f}[/SIMILAR]" + return prompt + + +class OptimEntry: + + def __init__(self, last_entry, mol_entries): + self.last_entry: MoleculeEntry = last_entry + self.mol_entries: List[MoleculeEntry] = mol_entries + + def to_prompt(self, is_generation, config): + prompt = "" + for mol_entry in self.mol_entries: + prompt += config["eos_token"] + if "default" in config["strategy"]: + prompt += create_prompt_with_similars(mol_entry=mol_entry) + elif "rej-sample-v2" in config["strategy"]: + prompt += create_prompt_with_similars(mol_entry=mol_entry) + prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" + else: + raise Exception(f"Strategy {config['strategy']} not known.") + prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" + + assert self.last_entry + prompt += config["eos_token"] + if is_generation: + prompt_with_similars = create_prompt_with_similars(self.last_entry, sim_range=config["sim_range"]) + else: + prompt_with_similars = create_prompt_with_similars(self.last_entry) + + if "default" in config["strategy"]: + prompt += prompt_with_similars + elif "rej-sample-v2" in config["strategy"]: + prompt += prompt_with_similars + if is_generation: + oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] + q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 + desired_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 + oracle_score = desired_oracle_score + else: + oracle_score = self.last_entry.score + prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]" + else: + raise Exception(f"Strategy {config['strategy']} not known.") + + if is_generation: + prompt += "[START_SMILES]" + else: + prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" + + return prompt \ No newline at end of file From efd8106911d72ed46848a047b17986631de108e5 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 17 May 2024 10:40:10 +0400 Subject: [PATCH 17/45] remove dublicates if any from the optim process --- chemlactica/mol_opt/optimization.py | 5 ++--- chemlactica/mol_opt/utils.py | 8 +++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 50fbad7..ef14b0a 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -110,7 +110,7 @@ def optimize( current_optim_entries = [] with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry: + if entry and not optim_entries[i].contains_entry(entry): current_mol_entries.append(entry) current_optim_entries.append(optim_entries[i]) @@ -130,7 +130,6 @@ def optimize( tol_level = 0 if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: break - if oracle.finish: break @@ -152,7 +151,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if num_iter % 3 == 0 and num_iter > initial_num_iter: + if num_iter % 5 == 0 and num_iter > 
initial_num_iter: training_entries = pool.optim_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 98bdfb0..1ee6e16 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -194,4 +194,10 @@ def to_prompt(self, is_generation, config): else: prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" - return prompt \ No newline at end of file + return prompt + + def contains_entry(self, mol_entry: MoleculeEntry): + for entry in self.mol_entries: + if mol_entry == entry: + return True + return False \ No newline at end of file From cade253e9dc96f3347788ea6291344fa5ebaaf68 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Fri, 17 May 2024 12:45:12 +0400 Subject: [PATCH 18/45] add hash type for molecular entries to store --- chemlactica/mol_opt/utils.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 98bdfb0..25f45a7 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -105,7 +105,9 @@ def add(self, entries: List[MoleculeEntry], diversity_score=1.0): for e in new_optim_entries: if ( entry == e - or tanimoto_dist_func(entry.last_entry.fingerprint, e.last_entry.fingerprint) + or tanimoto_dist_func( + entry.last_entry.fingerprint, e.last_entry.fingerprint + ) > diversity_score ): insert = False @@ -113,9 +115,7 @@ def add(self, entries: List[MoleculeEntry], diversity_score=1.0): if insert: new_optim_entries.append(entry) - self.optim_entries = new_optim_entries[ - : min(len(new_optim_entries), self.size) - ] + self.optim_entries = new_optim_entries[: min(len(new_optim_entries), self.size)] def random_subset(self, subset_size): rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) @@ -124,6 +124,9 @@ def random_subset(self, subset_size): def __len__(self): return len(self.optim_entries) + def __hash__(self): + return hash(self.smiles) + def make_output_files_base(input_path, results_dir, run_name, config): formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") @@ -142,14 +145,13 @@ def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None): prompt = "" for sim_mol_entry in mol_entry.add_props["similar_mol_entries"]: if sim_range: - prompt += f"[SIMILAR]{sim_mol_entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" + prompt += f"[SIMILAR]{sim_mol_entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" # noqa else: - prompt += f"[SIMILAR]{sim_mol_entry.smiles} {tanimoto_dist_func(sim_mol_entry.fingerprint, mol_entry.fingerprint):.2f}[/SIMILAR]" + prompt += f"[SIMILAR]{sim_mol_entry.smiles} {tanimoto_dist_func(sim_mol_entry.fingerprint, mol_entry.fingerprint):.2f}[/SIMILAR]" # noqa return prompt class OptimEntry: - def __init__(self, last_entry, mol_entries): self.last_entry: MoleculeEntry = last_entry self.mol_entries: List[MoleculeEntry] = mol_entries @@ -170,28 +172,36 @@ def to_prompt(self, is_generation, config): assert self.last_entry prompt += config["eos_token"] if is_generation: - prompt_with_similars = create_prompt_with_similars(self.last_entry, sim_range=config["sim_range"]) + prompt_with_similars = create_prompt_with_similars( + self.last_entry, sim_range=config["sim_range"] + ) else: prompt_with_similars = create_prompt_with_similars(self.last_entry) - + if "default" in 
config["strategy"]: prompt += prompt_with_similars elif "rej-sample-v2" in config["strategy"]: prompt += prompt_with_similars if is_generation: oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] - q_0_9 = np.quantile(oracle_scores_of_mols_in_prompt, 0.9) if oracle_scores_of_mols_in_prompt else 0 - desired_oracle_score = generate_random_number(q_0_9, 1.0) # TODO: change the hard coded 1.0 + q_0_9 = ( + np.quantile(oracle_scores_of_mols_in_prompt, 0.9) + if oracle_scores_of_mols_in_prompt + else 0 + ) + desired_oracle_score = generate_random_number( + q_0_9, 1.0 + ) # TODO: change the hard coded 1.0 oracle_score = desired_oracle_score else: oracle_score = self.last_entry.score prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") - + if is_generation: prompt += "[START_SMILES]" else: prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" - return prompt \ No newline at end of file + return prompt From a805c1bab8aba65fdadde0c14013f9f917f194c8 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Fri, 17 May 2024 13:08:21 +0400 Subject: [PATCH 19/45] add logit processors to package --- chemlactica/utils/logits_processors.py | 88 ++++++++++++++++++++++++++ chemlactica/utils/logits_utils.py | 79 +++++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 chemlactica/utils/logits_processors.py create mode 100644 chemlactica/utils/logits_utils.py diff --git a/chemlactica/utils/logits_processors.py b/chemlactica/utils/logits_processors.py new file mode 100644 index 0000000..4abe0cd --- /dev/null +++ b/chemlactica/utils/logits_processors.py @@ -0,0 +1,88 @@ +from transformers import LogitsProcessor +import torch +from typing import List, Union + + +class OneOccurenceLogitsProcessor(LogitsProcessor): + def __init__(self, suppress_tokens): + self.suppress_tokens = list(suppress_tokens) + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + for token_id in self.suppress_tokens: + if token_id in input_ids: + scores[:, token_id] = -float("inf") + return scores + + +class TunableExponentialDecayLengthPenalty(LogitsProcessor): + def __init__( + self, + exponential_decay_factors: List[float], + regulation_starts: List[int], + decay_token_ids: Union[int, List[int]], + input_ids_seq_length: int, + ): + self.regulation_starts = regulation_starts + self.regulation_list = exponential_decay_factors + if isinstance(decay_token_ids, int): + decay_token_ids = [decay_token_ids] + self.decay_token_ids = decay_token_ids + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + penalties = torch.zeros_like(scores) + scores_processed = scores + for token_id, regulation_factor, regulation_start in zip( + self.decay_token_ids, self.regulation_list, self.regulation_starts + ): + if cur_len > regulation_start: + penalty_idx = cur_len - regulation_start + penalty = torch.abs(scores[:, token_id]) * ( + pow(regulation_factor, penalty_idx) - 1 + ) + penalties[:, token_id] = penalty + scores_processed = scores + penalties + return scores_processed + + +class SequentialDecayProcessor(LogitsProcessor): + def __init__( + self, + exponential_decay_factors: List[float], + decay_token_ids: Union[int, List[int]], + input_ids_seq_length: int, + ): + self.regulation_list = exponential_decay_factors + if isinstance(decay_token_ids, int): + decay_token_ids = [decay_token_ids] + + 
self.decay_token_ids = decay_token_ids
+        # Per-token start steps for the decay. A large sentinel means "not
+        # started yet"; entry i is moved up to the current step whenever token i
+        # appears, and the first token's decay is active from step 0. This must
+        # be a list with one slot per token, not a scalar product.
+        self.regulation_starts = [99999999999999999] * len(self.decay_token_ids)
+        self.regulation_starts[0] = 0
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        penalties = torch.zeros_like(scores)
+        scores_processed = scores
+        for i, (token_id, regulation_factor) in enumerate(
+            zip(self.decay_token_ids, self.regulation_list)
+        ):
+            if token_id in input_ids:
+                self.regulation_starts[i] = cur_len
+
+            reg_start = self.regulation_starts[i]
+            if cur_len > reg_start:
+                penalty_idx = cur_len - reg_start
+                penalty = torch.abs(scores[:, token_id]) * (
+                    pow(regulation_factor, penalty_idx) - 1
+                )
+                penalties[:, token_id] = penalty
+                scores_processed = scores + penalties
+        return scores_processed
diff --git a/chemlactica/utils/logits_utils.py b/chemlactica/utils/logits_utils.py
new file mode 100644
index 0000000..21de004
--- /dev/null
+++ b/chemlactica/utils/logits_utils.py
@@ -0,0 +1,79 @@
+import yaml
+import os
+from typing import List, Any, Dict
+from transformers.generation import LogitsProcessor, LogitsProcessorList
+from dataclasses import dataclass, field
+import importlib
+import importlib.util
+
+
+def import_local_module(module_name, relative_path):
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    module_path = os.path.join(current_dir, relative_path)
+
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    if spec is None:
+        raise ImportError(f"Could not import module {module_name} from {module_path}")
+
+    # Create a new module based on the spec
+    module = importlib.util.module_from_spec(spec)
+
+    # Execute the module's code and populate the module
+    spec.loader.exec_module(module)
+
+    return module
+
+
+@dataclass
+class LogitsProcessorConfig:
+    class_name: str
+    is_local: bool
+    module: str
+    kwargs: Dict[str, Any]
+    path: str = field(default=None)
+
+
+def instantiate_processors(
+    config: List[LogitsProcessorConfig],
+) -> List[LogitsProcessor]:
+    processors = []
+    for processor_config in config:
+        if processor_config.is_local:
+            module = import_local_module(processor_config.module, processor_config.path)
+        else:
+            module = importlib.import_module(processor_config.module)
+        processor_class = getattr(module, processor_config.class_name)
+        processor = processor_class(**processor_config.kwargs)
+        processors.append(processor)
+    return processors
+
+
+def load_processor_config(file_path: str) -> List[LogitsProcessorConfig]:
+    with open(file_path, "r") as file:
+        config_data = yaml.safe_load(file)
+    configs = [
+        LogitsProcessorConfig(**processor)
+        for processor in config_data["logits_processors"]
+    ]
+    return configs
+
+
+def get_logits_processors(logits_processors_config_path=None):
+    # current_dir = os.path.dirname(os.path.abspath(__file__))
+    # config_file_path = os.path.join(current_dir, "logit_configs","best_config.yaml")
+    if logits_processors_config_path:
+        logit_processors_config = load_processor_config(logits_processors_config_path)
+        logit_processors = instantiate_processors(logit_processors_config)
+        logit_processor_list = LogitsProcessorList(logit_processors)
+        return logit_processor_list
+    else:
+        return None
+
+
+if __name__ == "__main__":
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    config_file_path = os.path.join(current_dir, "logit_configs", "best_config.yaml")
+    logit_processors_config = load_processor_config(config_file_path)
+    logit_processors =
instantiate_processors(logit_processors_config) + logit_processor_list = LogitsProcessorList(logit_processors) + print(logit_processor_list) From 20d3fe0f59d90a9546a9af64fc4dfaa1dfdc4525 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 17 May 2024 17:29:07 +0400 Subject: [PATCH 20/45] add train condition --- chemlactica/mol_opt/optimization.py | 9 ++++++--- chemlactica/mol_opt/utils.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index ef14b0a..de9940f 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -73,6 +73,7 @@ def optimize( max_score = 0 tol_level = 0 num_iter = 0 + prev_train_iter = 0 while True: model.eval() iter_optim_entries: List[OptimEntry] = [] @@ -138,7 +139,9 @@ def optimize( break initial_num_iter = num_iter num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] - print("num_iter: ", num_iter) + print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") + if num_iter > initial_num_iter: + tol_level += 1 # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) pool.add(iter_optim_entries) @@ -151,7 +154,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if num_iter % 5 == 0 and num_iter > initial_num_iter: + if config["rej_sample_config"]["train_condition"](num_iter, tol_level, prev_train_iter): training_entries = pool.optim_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") @@ -169,4 +172,4 @@ def optimize( supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) gc.collect() torch.cuda.empty_cache() - tol_level = 0 \ No newline at end of file + prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 33f081e..267a531 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -54,13 +54,13 @@ def canonicalize(smiles): class MoleculeEntry: - def __init__(self, smiles, score=None, **kwargs): + def __init__(self, smiles, score=0, **kwargs): self.smiles = smiles + self.score = score if smiles: self.smiles = canonicalize(smiles) self.mol = Chem.MolFromSmiles(smiles) self.fingerprint = get_morgan_fingerprint(self.mol) - self.score = score self.add_props = kwargs def __eq__(self, other): From bfa32d75fb8643e121db45bbab299bce1b83f936 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 18 May 2024 15:40:57 +0400 Subject: [PATCH 21/45] don't give oracle score tag before the first training --- chemlactica/mol_opt/optimization.py | 9 ++++++--- chemlactica/mol_opt/utils.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index de9940f..0ed4ffc 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -89,7 +89,10 @@ def optimize( ) optim_entries[i].last_entry = last_entry - prompts = [optim_entry.to_prompt(is_generation=True, config=config) for optim_entry in optim_entries] + prompts = [ + optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) + for optim_entry in optim_entries + ] output_texts = [] for i in range(0, len(prompts), config["generation_batch_size"]): 
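
The `should_train` hook used in the hunk below is supplied through the config
rather than fixed by this patch, so its schedule is caller-defined. A minimal
sketch of one plausible hook, using the signature from the diff; the
three-iteration cadence is illustrative only:

    # retrain via rejection-sampling fine-tuning once at least three
    # optimization iterations have passed since the previous training round
    config["rej_sample_config"]["should_train"] = (
        lambda num_iter, tol_level, prev_train_iter: num_iter - prev_train_iter >= 3
    )
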
@@ -154,7 +157,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if config["rej_sample_config"]["train_condition"](num_iter, tol_level, prev_train_iter): + if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): training_entries = pool.optim_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") @@ -163,7 +166,7 @@ def optimize( train_dataset = Dataset.from_dict({ "sample": [ - optim_entry.to_prompt(is_generation=False, config=config) + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) for optim_entry in training_entries ] }) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 267a531..337c7fc 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -156,21 +156,23 @@ def __init__(self, last_entry, mol_entries): self.last_entry: MoleculeEntry = last_entry self.mol_entries: List[MoleculeEntry] = mol_entries - def to_prompt(self, is_generation, config): + def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): prompt = "" + prompt = config["eos_token"] for mol_entry in self.mol_entries: - prompt += config["eos_token"] + # prompt += config["eos_token"] if "default" in config["strategy"]: prompt += create_prompt_with_similars(mol_entry=mol_entry) elif "rej-sample-v2" in config["strategy"]: prompt += create_prompt_with_similars(mol_entry=mol_entry) - prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" + if include_oracle_score: + prompt += f"[ORACLE_SCORE]{mol_entry.score:.2f}[/ORACLE_SCORE]" else: raise Exception(f"Strategy {config['strategy']} not known.") prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" assert self.last_entry - prompt += config["eos_token"] + # prompt += config["eos_token"] if is_generation: prompt_with_similars = create_prompt_with_similars( self.last_entry, sim_range=config["sim_range"] @@ -195,7 +197,8 @@ def to_prompt(self, is_generation, config): oracle_score = desired_oracle_score else: oracle_score = self.last_entry.score - prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]" + if include_oracle_score: + prompt += f"[ORACLE_SCORE]{oracle_score:.2f}[/ORACLE_SCORE]" else: raise Exception(f"Strategy {config['strategy']} not known.") From d51cbfbb7bf457b6c3e37f97cd3f790797687041 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 20 May 2024 14:38:02 +0400 Subject: [PATCH 22/45] add validation set to fine tunning --- chemlactica/mol_opt/optimization.py | 28 ++++++--- chemlactica/mol_opt/tunning.py | 88 +++++++++++++++++++---------- chemlactica/mol_opt/utils.py | 48 +++++++++++++--- 3 files changed, 120 insertions(+), 44 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 0ed4ffc..0b0f167 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -68,7 +68,7 @@ def optimize( file = open(config["log_dir"], "w") print("config", config) # print("molecule generation arguments", config["generation_config"]) - pool = Pool(config["pool_size"]) + pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) max_score = 0 tol_level = 0 @@ -129,7 +129,7 @@ def optimize( current_optim_entries[i].last_entry = entry iter_optim_entries.append(current_optim_entries[i]) file.write(f"generated smiles: 
{entry.smiles}, score: {entry.score:.4f}\n") - if entry.score > max_score + 0.01: + if entry.score > max_score: max_score = entry.score tol_level = 0 if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: @@ -158,21 +158,35 @@ def optimize( # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): - training_entries = pool.optim_entries - print(f"Num of train examples {len(training_entries)}.") + train_entries, validation_entries = pool.get_train_valid_entries() + print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") file.write("Training entries\n") - for i, optim_entry in enumerate(training_entries): + for i, optim_entry in enumerate(train_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + file.write("Validation entries\n") + for i, optim_entry in enumerate(validation_entries): file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") train_dataset = Dataset.from_dict({ "sample": [ optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) - for optim_entry in training_entries + for optim_entry in train_entries + ] + }) + validation_dataset = Dataset.from_dict({ + "sample": [ + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + for optim_entry in validation_entries ] }) train_dataset.shuffle(seed=42) + validation_dataset.shuffle(seed=42) config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"]) + supervised_fine_tune( + model, tokenizer, + train_dataset, validation_dataset, + config["rej_sample_config"] + ) gc.collect() torch.cuda.empty_cache() prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index a4c524c..d57d88f 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -1,39 +1,65 @@ -from transformers.trainer_callback import TrainerControl, TrainerState +from transformers.trainer_callback import TrainerControl, TrainerState, TrainerCallback from trl import SFTTrainer, DataCollatorForCompletionOnlyLM -from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, TrainerCallback +from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, EarlyStoppingCallback from torch.optim.lr_scheduler import ConstantLR import torch import math -class CustomSFTTrainer(SFTTrainer): +class CustomEarlyStopCallback(TrainerCallback): - def __init__(self, *args, patience, toll, **kwargs): - super().__init__(*args, **kwargs) - self.patience = patience - self.initial_pat = patience - self.toll = toll - self.best_loss = math.inf + def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None: + super().__init__() + self.best_valid_loss = math.inf + self.early_stopping_patience = early_stopping_patience + self.current_patiance = 0 + self.early_stopping_threshold = early_stopping_threshold - def log(self, logs) -> None: - if logs.get("loss"): - curr_loss = logs["loss"] - if curr_loss > self.best_loss - self.toll: - self.patience -= 1 - print(f"loss did not improve, patience {self.patience}") - else: - 
print("loss improved") - self.best_loss = curr_loss - self.patience = self.initial_pat - if self.patience == 0: - print("The loss does not improve, stop training.") - self.control.should_training_stop = True - return super().log(logs) + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.best_valid_loss = math.inf + self.current_patiance = 0 + return super().on_train_begin(args, state, control, **kwargs) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): + if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold: + self.current_patiance += 1 + else: + self.current_patiance = 0 + self.best_valid_loss = metrics["eval_loss"] + print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}") + if self.current_patiance >= self.early_stopping_patience: + control.should_training_stop = True + return super().on_evaluate(args, state, control, **kwargs) + + +# class CustomSFTTrainer(SFTTrainer): +# + # def __init__(self, *args, patience, toll, **kwargs): + # super().__init__(*args, **kwargs) + # self.patience = patience + # self.initial_pat = patience + # self.toll = toll + # self.best_loss = math.inf + + # def log(self, logs) -> None: + # if logs.get("loss"): + # curr_loss = logs["loss"] + # if curr_loss > self.best_loss - self.toll: + # self.patience -= 1 + # print(f"loss did not improve, patience {self.patience}") + # else: + # print("loss improved") + # self.best_loss = curr_loss + # self.patience = self.initial_pat + # if self.patience == 0: + # print("The loss does not improve, stop training.") + # self.control.should_training_stop = True + # return super().log(logs) def supervised_fine_tune( model, tokenizer, - train_dataset, config + train_dataset, validation_dataset, config ): model.train() training_args = TrainingArguments( @@ -41,13 +67,13 @@ def supervised_fine_tune( per_device_train_batch_size=config["train_batch_size"], max_grad_norm=config["global_gradient_norm"], num_train_epochs=config["num_train_epochs"], - evaluation_strategy="no", + evaluation_strategy="epoch", dataloader_drop_last=False, dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], gradient_accumulation_steps=config["gradient_accumulation_steps"], logging_steps=1, - metric_for_best_model="loss", + metric_for_best_model="loss" ) optimizer = torch.optim.AdamW( model.parameters(), @@ -65,9 +91,14 @@ def supervised_fine_tune( collator = DataCollatorForCompletionOnlyLM( config["response_template"], tokenizer=tokenizer ) - trainer = CustomSFTTrainer( + early_stopping_callback = CustomEarlyStopCallback( + early_stopping_patience=1, + early_stopping_threshold=0.001 + ) + trainer = SFTTrainer( model=model, train_dataset=train_dataset, + eval_dataset=validation_dataset, formatting_func=config["formatting_func"], args=training_args, packing=config["packing"], @@ -75,7 +106,6 @@ def supervised_fine_tune( max_seq_length=config["max_seq_length"], # data_collator=collator, optimizers=[optimizer, lr_scheduler], - patience=2, - toll=0.0001 + callbacks=[early_stopping_callback], ) trainer.train() diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 337c7fc..6037ebb 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -82,9 +82,10 @@ def __repr__(self): class Pool: - def __init__(self, size): + def __init__(self, size, validation_perc: float): self.size = size self.optim_entries: 
List[OptimEntry] = [] + self.num_validation_entries = int(size * validation_perc) # def random_dump(self, num): # for _ in range(num): @@ -92,12 +93,11 @@ def __init__(self, size): # self.molecule_entries.pop(rand_ind) # print(f"Dump {num} random elements from pool, num pool mols {len(self)}") - def add(self, entries: List[MoleculeEntry], diversity_score=1.0): + def add(self, entries: List, diversity_score=1.0): assert type(entries) == list self.optim_entries.extend(entries) self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True) - # print(f"Updating with div_score {diversity_score:.4f}") # remove doublicates new_optim_entries = [] for entry in self.optim_entries: @@ -116,6 +116,34 @@ def add(self, entries: List[MoleculeEntry], diversity_score=1.0): new_optim_entries.append(entry) self.optim_entries = new_optim_entries[: min(len(new_optim_entries), self.size)] + curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries]) + + i = 0 + while curr_num_validation_entries < self.num_validation_entries: + if self.optim_entries[i].entry_status == EntryStatus.none: + self.optim_entries[i].entry_status = EntryStatus.valid + curr_num_validation_entries += 1 + i += 1 + + for j in range(i, len(self.optim_entries)): + if self.optim_entries[j].entry_status == EntryStatus.none: + self.optim_entries[j].entry_status = EntryStatus.train + + curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries]) + assert curr_num_validation_entries == min(len(self.optim_entries), self.num_validation_entries) + + def get_train_valid_entries(self): + train_entries = [] + valid_entries = [] + for entry in self.optim_entries: + if entry.entry_status == EntryStatus.train: + train_entries.append(entry) + elif entry.entry_status == EntryStatus.valid: + valid_entries.append(entry) + else: + raise Exception(f"EntryStatus of an entry in pool cannot be {entry.entry_status}.") + assert min(len(self.optim_entries), self.num_validation_entries) == len(valid_entries) + return train_entries, valid_entries def random_subset(self, subset_size): rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) @@ -124,9 +152,6 @@ def random_subset(self, subset_size): def __len__(self): return len(self.optim_entries) - def __hash__(self): - return hash(self.smiles) - def make_output_files_base(input_path, results_dir, run_name, config): formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") @@ -151,10 +176,17 @@ def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None): return prompt +class EntryStatus: + none = 0 + train = 1 + valid = 2 + + class OptimEntry: def __init__(self, last_entry, mol_entries): self.last_entry: MoleculeEntry = last_entry self.mol_entries: List[MoleculeEntry] = mol_entries + self.entry_status: EntryStatus = EntryStatus.none def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): prompt = "" @@ -192,8 +224,8 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): else 0 ) desired_oracle_score = generate_random_number( - q_0_9, 1.0 - ) # TODO: change the hard coded 1.0 + q_0_9, config["max_possible_oracle_score"] + ) oracle_score = desired_oracle_score else: oracle_score = self.last_entry.score From 1502e65daff46a864de574b01682d113d8accf4f Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 20 May 2024 17:50:19 +0400 Subject: [PATCH 23/45] replace [ORACLE_SCORE] with [PROPERTY] tag --- chemlactica/mol_opt/utils.py | 4 ++-- 1 
file changed, 2 insertions(+), 2 deletions(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 6037ebb..6eeb17a 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -198,7 +198,7 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): elif "rej-sample-v2" in config["strategy"]: prompt += create_prompt_with_similars(mol_entry=mol_entry) if include_oracle_score: - prompt += f"[ORACLE_SCORE]{mol_entry.score:.2f}[/ORACLE_SCORE]" + prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" @@ -230,7 +230,7 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): else: oracle_score = self.last_entry.score if include_oracle_score: - prompt += f"[ORACLE_SCORE]{oracle_score:.2f}[/ORACLE_SCORE]" + prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") From 188f3847a45b65947d27dc40cc7d9d96c1ba5dee Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 21 May 2024 14:02:02 +0400 Subject: [PATCH 24/45] add additional properties --- chemlactica/mol_opt/optimization.py | 24 ++++++++++++++---------- chemlactica/mol_opt/tunning.py | 3 --- chemlactica/mol_opt/utils.py | 17 +++++++++++++++-- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 0b0f167..5c8ad33 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -6,6 +6,7 @@ import math import tqdm import random +from functools import partial import numpy as np from transformers import OPTForCausalLM from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func @@ -33,7 +34,7 @@ def create_optimization_entries(num_entries, pool, config): entries = [] for mol_entry in mol_entries: similar_mol_entries = create_similar_mol_entries(pool, mol_entry, num_similars=config["num_similars"]) - mol_entry.add_props["similar_mol_entries"] = similar_mol_entries + mol_entry.similar_mol_entries = similar_mol_entries entries.append(mol_entry) optim_entries.append(OptimEntry(None, entries)) return optim_entries @@ -49,21 +50,19 @@ def create_molecule_entry(output_text): if len(generated_smiles) == 0: return None - for output in output_text.split(start_smiles_tag)[:-1]: - smiles_in_prompt = output.split(end_smiles_tag)[0] - if generated_smiles == smiles_in_prompt: - return None try: - return MoleculeEntry( + molecule = MoleculeEntry( smiles=generated_smiles, ) + return molecule except: return None def optimize( model, tokenizer, - oracle, config + oracle, config, + additional_properties=[] ): file = open(config["log_dir"], "w") print("config", config) @@ -84,9 +83,11 @@ def optimize( ) for i in range(len(optim_entries)): last_entry = MoleculeEntry(smiles="") - last_entry.add_props["similar_mol_entries"] = create_similar_mol_entries( + last_entry.similar_mol_entries = create_similar_mol_entries( pool, last_entry, config["num_similars"] ) + for prop_name, prop_spec in additional_properties.items(): + last_entry.add_props[prop_name] = prop_spec optim_entries[i].last_entry = last_entry prompts = [ @@ -125,7 +126,10 @@ def optimize( for i, oracle_score in enumerate(oracle_scores): entry = current_mol_entries[i] entry.score = oracle_score - entry.add_props["similar_mol_entries"] = 
current_optim_entries[i].last_entry.add_props["similar_mol_entries"] + entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries + for prop_name, prop_spec in additional_properties.items(): + entry.add_props[prop_name] = prop_spec + entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry) current_optim_entries[i].last_entry = entry iter_optim_entries.append(current_optim_entries[i]) file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") @@ -142,9 +146,9 @@ def optimize( break initial_num_iter = num_iter num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] - print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") if num_iter > initial_num_iter: tol_level += 1 + print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) pool.add(iter_optim_entries) diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index d57d88f..d04c639 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -88,9 +88,6 @@ def supervised_fine_tune( lr_end=0.999 * config["max_learning_rate"], power=1.0, ) - collator = DataCollatorForCompletionOnlyLM( - config["response_template"], tokenizer=tokenizer - ) early_stopping_callback = CustomEarlyStopCallback( early_stopping_patience=1, early_stopping_threshold=0.001 diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 6eeb17a..84d7589 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -57,6 +57,7 @@ class MoleculeEntry: def __init__(self, smiles, score=0, **kwargs): self.smiles = smiles self.score = score + self.similar_mol_entries = [] if smiles: self.smiles = canonicalize(smiles) self.mol = Chem.MolFromSmiles(smiles) @@ -168,7 +169,7 @@ def make_output_files_base(input_path, results_dir, run_name, config): def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None): prompt = "" - for sim_mol_entry in mol_entry.add_props["similar_mol_entries"]: + for sim_mol_entry in mol_entry.similar_mol_entries: if sim_range: prompt += f"[SIMILAR]{sim_mol_entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]" # noqa else: @@ -188,7 +189,10 @@ def __init__(self, last_entry, mol_entries): self.mol_entries: List[MoleculeEntry] = mol_entries self.entry_status: EntryStatus = EntryStatus.none - def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): + def to_prompt( + self, is_generation: bool, + include_oracle_score: bool, config, + ): prompt = "" prompt = config["eos_token"] for mol_entry in self.mol_entries: @@ -201,6 +205,8 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") + for prop_name, prop_spec in mol_entry.add_props.items(): + prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}" prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" assert self.last_entry @@ -234,6 +240,9 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config): else: raise Exception(f"Strategy {config['strategy']} not known.") + for prop_name, prop_spec in self.last_entry.add_props.items(): + prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"] + if 
is_generation: prompt += "[START_SMILES]" else: @@ -245,4 +254,8 @@ def contains_entry(self, mol_entry: MoleculeEntry): for entry in self.mol_entries: if mol_entry == entry: return True + for sim_entry in entry.similar_mol_entries: + if mol_entry == sim_entry: + return True + return False From d95e4d4dcce4d28fe7f58ac471a90ceae2a69769 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 21 May 2024 14:44:09 +0400 Subject: [PATCH 25/45] correct properties order --- chemlactica/mol_opt/optimization.py | 3 +-- chemlactica/mol_opt/utils.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 5c8ad33..0e5e530 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -62,7 +62,7 @@ def create_molecule_entry(output_text): def optimize( model, tokenizer, oracle, config, - additional_properties=[] + additional_properties={} ): file = open(config["log_dir"], "w") print("config", config) @@ -94,7 +94,6 @@ def optimize( optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) for optim_entry in optim_entries ] - output_texts = [] for i in range(0, len(prompts), config["generation_batch_size"]): prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 84d7589..256b4c2 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -197,16 +197,18 @@ def to_prompt( prompt = config["eos_token"] for mol_entry in self.mol_entries: # prompt += config["eos_token"] + prompt += create_prompt_with_similars(mol_entry=mol_entry) + + for prop_name, prop_spec in mol_entry.add_props.items(): + prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}" + if "default" in config["strategy"]: - prompt += create_prompt_with_similars(mol_entry=mol_entry) + pass elif "rej-sample-v2" in config["strategy"]: - prompt += create_prompt_with_similars(mol_entry=mol_entry) if include_oracle_score: prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") - for prop_name, prop_spec in mol_entry.add_props.items(): - prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}" prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" assert self.last_entry @@ -218,10 +220,14 @@ def to_prompt( else: prompt_with_similars = create_prompt_with_similars(self.last_entry) + prompt += prompt_with_similars + + for prop_name, prop_spec in self.last_entry.add_props.items(): + prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"] + if "default" in config["strategy"]: - prompt += prompt_with_similars + pass elif "rej-sample-v2" in config["strategy"]: - prompt += prompt_with_similars if is_generation: oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] q_0_9 = ( @@ -240,9 +246,6 @@ def to_prompt( else: raise Exception(f"Strategy {config['strategy']} not known.") - for prop_name, prop_spec in self.last_entry.add_props.items(): - prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"] - if is_generation: prompt += "[START_SMILES]" else: From 800278b455bab84a173bbcd0b44964a9cea492c1 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 22 May 2024 01:03:58 +0400 Subject: [PATCH 26/45] fix entry dublicates in pool issue 
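The fix below operates at the level of canonical SMILES: MoleculeEntry equality is keyed on the canonicalized string, so comparing `entry.last_entry` objects directly catches molecules that differ only in SMILES spelling, which the previous comparison of the wrapper OptimEntry objects evidently missed. A rough, self-contained sketch of that contract (the `__eq__` shown is an assumption based on how `canonicalize` is used elsewhere in this series, not the committed implementation):

from rdkit import Chem

class MoleculeEntrySketch:
    """Illustrative stand-in for MoleculeEntry's dedup behavior."""

    def __init__(self, smiles: str):
        # Canonicalize on construction, as utils.canonicalize does.
        self.smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)

    def __eq__(self, other) -> bool:
        # Two entries count as duplicates when their canonical SMILES match,
        # regardless of how the input strings were written.
        return self.smiles == other.smiles

assert MoleculeEntrySketch("OCC") == MoleculeEntrySketch("CCO")  # same ethanol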
--- chemlactica/mol_opt/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 256b4c2..aa61c93 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -105,7 +105,7 @@ def add(self, entries: List, diversity_score=1.0): insert = True for e in new_optim_entries: if ( - entry == e + entry.last_entry == e.last_entry or tanimoto_dist_func( entry.last_entry.fingerprint, e.last_entry.fingerprint ) From 1306842f16ab726d55ee236fd5dd705996bae315 Mon Sep 17 00:00:00 2001 From: philipp guevorguian Date: Wed, 22 May 2024 00:14:46 +0000 Subject: [PATCH 27/45] pre merge --- chemlactica/mol_opt/optimization.py | 248 ++++++++++++++-------------- chemlactica/mol_opt/utils.py | 4 +- 2 files changed, 127 insertions(+), 125 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 0e5e530..cb644a7 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -64,132 +64,132 @@ def optimize( oracle, config, additional_properties={} ): - file = open(config["log_dir"], "w") - print("config", config) - # print("molecule generation arguments", config["generation_config"]) - pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) - - max_score = 0 - tol_level = 0 - num_iter = 0 - prev_train_iter = 0 - while True: - model.eval() - iter_optim_entries: List[OptimEntry] = [] - while len(iter_optim_entries) < config["num_gens_per_iter"]: - optim_entries = create_optimization_entries( - config["num_gens_per_iter"], pool, - config=config - ) - for i in range(len(optim_entries)): - last_entry = MoleculeEntry(smiles="") - last_entry.similar_mol_entries = create_similar_mol_entries( - pool, last_entry, config["num_similars"] + with open(config["log_dir"], "w") as file: + print("config", config) + # print("molecule generation arguments", config["generation_config"]) + pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) + + max_score = 0 + tol_level = 0 + num_iter = 0 + prev_train_iter = 0 + while True: + model.eval() + iter_optim_entries: List[OptimEntry] = [] + while len(iter_optim_entries) < config["num_gens_per_iter"]: + optim_entries = create_optimization_entries( + config["num_gens_per_iter"], pool, + config=config ) - for prop_name, prop_spec in additional_properties.items(): - last_entry.add_props[prop_name] = prop_spec - optim_entries[i].last_entry = last_entry - - prompts = [ - optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) - for optim_entry in optim_entries - ] - output_texts = [] - for i in range(0, len(prompts), config["generation_batch_size"]): - prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] - data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) - if type(model) == OPTForCausalLM: - del data["token_type_ids"] - for key, value in data.items(): - data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] - output = model.generate( - **data, - **config["generation_config"] - ) - gc.collect() - torch.cuda.empty_cache() - output_texts.extend(tokenizer.batch_decode(output)) - - current_mol_entries = [] - current_optim_entries = [] - with multiprocessing.Pool(processes=config["num_processes"]) as pol: - for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry and not optim_entries[i].contains_entry(entry): - 
current_mol_entries.append(entry) - current_optim_entries.append(optim_entries[i]) - - if getattr(oracle, "takes_entry", False): - oracle_scores = oracle(current_mol_entries) - else: - oracle_scores = oracle([e.smiles for e in current_mol_entries]) - for i, oracle_score in enumerate(oracle_scores): - entry = current_mol_entries[i] - entry.score = oracle_score - entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries - for prop_name, prop_spec in additional_properties.items(): - entry.add_props[prop_name] = prop_spec - entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry) - current_optim_entries[i].last_entry = entry - iter_optim_entries.append(current_optim_entries[i]) - file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") - if entry.score > max_score: - max_score = entry.score - tol_level = 0 - if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: + for i in range(len(optim_entries)): + last_entry = MoleculeEntry(smiles="") + last_entry.similar_mol_entries = create_similar_mol_entries( + pool, last_entry, config["num_similars"] + ) + for prop_name, prop_spec in additional_properties.items(): + last_entry.add_props[prop_name] = prop_spec + optim_entries[i].last_entry = last_entry + + prompts = [ + optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) + for optim_entry in optim_entries + ] + output_texts = [] + for i in range(0, len(prompts), config["generation_batch_size"]): + prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] + data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) + if type(model) == OPTForCausalLM: + del data["token_type_ids"] + for key, value in data.items(): + data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] + output = model.generate( + **data, + **config["generation_config"] + ) + gc.collect() + torch.cuda.empty_cache() + output_texts.extend(tokenizer.batch_decode(output)) + + current_mol_entries = [] + current_optim_entries = [] + with multiprocessing.Pool(processes=config["num_processes"]) as pol: + for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): + if entry and not optim_entries[i].contains_entry(entry): + current_mol_entries.append(entry) + current_optim_entries.append(optim_entries[i]) + + if getattr(oracle, "takes_entry", False): + oracle_scores = oracle(current_mol_entries) + else: + oracle_scores = oracle([e.smiles for e in current_mol_entries]) + for i, oracle_score in enumerate(oracle_scores): + entry = current_mol_entries[i] + entry.score = oracle_score + entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries + for prop_name, prop_spec in additional_properties.items(): + entry.add_props[prop_name] = prop_spec + entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry) + current_optim_entries[i].last_entry = entry + iter_optim_entries.append(current_optim_entries[i]) + file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") + if entry.score > max_score: + max_score = entry.score + tol_level = 0 + if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: + break + + if oracle.finish: break if oracle.finish: break - - if oracle.finish: - break - initial_num_iter = num_iter - num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] - if num_iter > initial_num_iter: - tol_level += 1 - 
print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") - - # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - pool.add(iter_optim_entries) - file.write("Pool\n") - for i, optim_entry in enumerate(pool.optim_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - if "rej-sample-v2" in config["strategy"]: - # round_entries.extend(current_entries) - # round_entries = list(np.unique(round_entries))[::-1] - # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) - # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): - train_entries, validation_entries = pool.get_train_valid_entries() - print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") - file.write("Training entries\n") - for i, optim_entry in enumerate(train_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - file.write("Validation entries\n") - for i, optim_entry in enumerate(validation_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - train_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) - for optim_entry in train_entries - ] - }) - validation_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) - for optim_entry in validation_entries - ] - }) - train_dataset.shuffle(seed=42) - validation_dataset.shuffle(seed=42) - config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune( - model, tokenizer, - train_dataset, validation_dataset, - config["rej_sample_config"] - ) - gc.collect() - torch.cuda.empty_cache() - prev_train_iter = num_iter \ No newline at end of file + initial_num_iter = num_iter + num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] + if num_iter > initial_num_iter: + tol_level += 1 + print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") + + # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) + pool.add(iter_optim_entries) + file.write("Pool\n") + for i, optim_entry in enumerate(pool.optim_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + + if "rej-sample-v2" in config["strategy"]: + # round_entries.extend(current_entries) + # round_entries = list(np.unique(round_entries))[::-1] + # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) + # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: + if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): + train_entries, validation_entries = pool.get_train_valid_entries() + print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") + file.write("Training entries\n") + for i, optim_entry in enumerate(train_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + file.write("Validation entries\n") + for i, optim_entry in enumerate(validation_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: 
{optim_entry.last_entry.score:.4f}\n") + + train_dataset = Dataset.from_dict({ + "sample": [ + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + for optim_entry in train_entries + ] + }) + validation_dataset = Dataset.from_dict({ + "sample": [ + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + for optim_entry in validation_entries + ] + }) + train_dataset.shuffle(seed=42) + validation_dataset.shuffle(seed=42) + config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] + supervised_fine_tune( + model, tokenizer, + train_dataset, validation_dataset, + config["rej_sample_config"] + ) + gc.collect() + torch.cuda.empty_cache() + prev_train_iter = num_iter diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 256b4c2..23d7fcc 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -80,6 +80,8 @@ def __str__(self): def __repr__(self): return str(self) + def __hash__(self): + return hash(self.smiles) class Pool: @@ -94,7 +96,7 @@ def __init__(self, size, validation_perc: float): # self.molecule_entries.pop(rand_ind) # print(f"Dump {num} random elements from pool, num pool mols {len(self)}") - def add(self, entries: List, diversity_score=1.0): + def add(self, entries: List, diversity_score=1.r9): assert type(entries) == list self.optim_entries.extend(entries) self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True) From b98a2cb2b2bea3f32598af87fe5db3488adb9ace Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 22 May 2024 09:02:48 +0000 Subject: [PATCH 28/45] add validation batch size, to avoid memory error --- chemlactica/mol_opt/tunning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index d04c639..da587ac 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -65,6 +65,7 @@ def supervised_fine_tune( training_args = TrainingArguments( output_dir=config["checkpoints_dir"], per_device_train_batch_size=config["train_batch_size"], + per_device_eval_batch_size=config["train_batch_size"], max_grad_norm=config["global_gradient_norm"], num_train_epochs=config["num_train_epochs"], evaluation_strategy="epoch", From 32cc5aa012047e2779ede99bf18675b7efc33ea9 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 22 May 2024 09:08:45 +0000 Subject: [PATCH 29/45] small fix --- chemlactica/mol_opt/optimization.py | 3 ++- chemlactica/mol_opt/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index cb644a7..407bc7c 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -106,9 +106,10 @@ def optimize( **data, **config["generation_config"] ) + output_texts.extend(tokenizer.batch_decode(output)) + del output gc.collect() torch.cuda.empty_cache() - output_texts.extend(tokenizer.batch_decode(output)) current_mol_entries = [] current_optim_entries = [] diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 04d0953..3908902 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -96,7 +96,7 @@ def __init__(self, size, validation_perc: float): # self.molecule_entries.pop(rand_ind) # print(f"Dump {num} random elements from pool, num pool mols {len(self)}") - def add(self, entries: List, diversity_score=1.r9): + def add(self, entries: List, diversity_score=1.0): assert type(entries) == list 
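# Aside (illustration, not part of this diff): the optimization.py half of this
# patch frees the generation output before emptying the CUDA cache, so the
# allocator can actually return the memory. A minimal sketch of that pattern,
# with hypothetical model/tokenizer/data arguments standing in for the
# surrounding code's objects:
import gc
import torch

def generate_and_free(model, tokenizer, data, generation_config) -> list:
    output = model.generate(**data, **generation_config)
    texts = tokenizer.batch_decode(output)  # decode while the tensor is alive
    del output                    # drop the last reference to the CUDA tensor
    gc.collect()                  # collect lingering Python-side references
    torch.cuda.empty_cache()      # cached blocks can now be returned to the driver
    return texts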
self.optim_entries.extend(entries) self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True) From 0ac3a7be962e741d35d4f33c99072a5ae77d88bf Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 25 May 2024 14:16:32 +0400 Subject: [PATCH 30/45] small fixes --- chemlactica/mol_opt/metrics.py | 51 ++++++ chemlactica/mol_opt/optimization.py | 266 +++++++++++++++------------- chemlactica/mol_opt/utils.py | 30 ++-- 3 files changed, 208 insertions(+), 139 deletions(-) create mode 100644 chemlactica/mol_opt/metrics.py diff --git a/chemlactica/mol_opt/metrics.py b/chemlactica/mol_opt/metrics.py new file mode 100644 index 0000000..4354519 --- /dev/null +++ b/chemlactica/mol_opt/metrics.py @@ -0,0 +1,51 @@ +import numpy as np +import torch + + +def average_agg_tanimoto(stock_vecs, gen_vecs, + batch_size=5000, agg='max', + device='cpu', p=1): + """ + For each molecule in gen_vecs finds closest molecule in stock_vecs. + Returns average tanimoto score for between these molecules + + Parameters: + stock_vecs: numpy array + gen_vecs: numpy array + agg: max or mean + p: power for averaging: (mean x^p)^(1/p) + """ + assert agg in ['max', 'mean'], "Can aggregate only max or mean" + agg_tanimoto = np.zeros(len(gen_vecs)) + total = np.zeros(len(gen_vecs)) + for j in range(0, stock_vecs.shape[0], batch_size): + x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() + for i in range(0, gen_vecs.shape[0], batch_size): + y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() + y_gen = y_gen.transpose(0, 1) + tp = torch.mm(x_stock, y_gen) + jac = (tp / (x_stock.sum(1, keepdim=True) + + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() + jac[np.isnan(jac)] = 1 + if p != 1: + jac = jac**p + if agg == 'max': + agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum( + agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) + elif agg == 'mean': + agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) + total[i:i + y_gen.shape[1]] += jac.shape[0] + if agg == 'mean': + agg_tanimoto /= total + if p != 1: + agg_tanimoto = (agg_tanimoto)**(1/p) + return np.mean(agg_tanimoto) + + +def internal_diversity(molecule_fingerprints, device='cpu', fp_type='morgan', p=1): + """ + Computes internal diversity as: + 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y)) + """ + return 1 - (average_agg_tanimoto(molecule_fingerprints, molecule_fingerprints, + agg='mean', device=device, p=p)).mean() \ No newline at end of file diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 407bc7c..c4f2357 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -64,133 +64,147 @@ def optimize( oracle, config, additional_properties={} ): - with open(config["log_dir"], "w") as file: - print("config", config) - # print("molecule generation arguments", config["generation_config"]) - pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) - - max_score = 0 - tol_level = 0 - num_iter = 0 - prev_train_iter = 0 - while True: - model.eval() - iter_optim_entries: List[OptimEntry] = [] - while len(iter_optim_entries) < config["num_gens_per_iter"]: - optim_entries = create_optimization_entries( - config["num_gens_per_iter"], pool, - config=config + file = open(config["log_dir"], "w") + print("config", config) + # print("molecule generation arguments", config["generation_config"]) + pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) + + config["generation_config"]["temperature"] = config["generation_temperature"][0] + max_score = 0 + 
tol_level = 0 + num_iter = 0 + prev_train_iter = 0 + while True: + model.eval() + new_best_molecule_generated = False + iter_unique_optim_entries: List[OptimEntry] = {} + while len(iter_unique_optim_entries) < config["num_gens_per_iter"]: + optim_entries = create_optimization_entries( + config["num_gens_per_iter"], pool, + config=config + ) + for i in range(len(optim_entries)): + last_entry = MoleculeEntry(smiles="") + last_entry.similar_mol_entries = create_similar_mol_entries( + pool, last_entry, config["num_similars"] ) - for i in range(len(optim_entries)): - last_entry = MoleculeEntry(smiles="") - last_entry.similar_mol_entries = create_similar_mol_entries( - pool, last_entry, config["num_similars"] - ) - for prop_name, prop_spec in additional_properties.items(): - last_entry.add_props[prop_name] = prop_spec - optim_entries[i].last_entry = last_entry - - prompts = [ - optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) - for optim_entry in optim_entries - ] - output_texts = [] - for i in range(0, len(prompts), config["generation_batch_size"]): - prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] - data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) - if type(model) == OPTForCausalLM: - del data["token_type_ids"] - for key, value in data.items(): - data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] - output = model.generate( - **data, - **config["generation_config"] - ) - output_texts.extend(tokenizer.batch_decode(output)) - del output - gc.collect() - torch.cuda.empty_cache() - - current_mol_entries = [] - current_optim_entries = [] - with multiprocessing.Pool(processes=config["num_processes"]) as pol: - for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry and not optim_entries[i].contains_entry(entry): - current_mol_entries.append(entry) - current_optim_entries.append(optim_entries[i]) - - if getattr(oracle, "takes_entry", False): - oracle_scores = oracle(current_mol_entries) - else: - oracle_scores = oracle([e.smiles for e in current_mol_entries]) - for i, oracle_score in enumerate(oracle_scores): - entry = current_mol_entries[i] - entry.score = oracle_score - entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries - for prop_name, prop_spec in additional_properties.items(): - entry.add_props[prop_name] = prop_spec - entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry) - current_optim_entries[i].last_entry = entry - iter_optim_entries.append(current_optim_entries[i]) - file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n") - if entry.score > max_score: - max_score = entry.score - tol_level = 0 - if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: - break - - if oracle.finish: - break + for prop_name, prop_spec in additional_properties.items(): + last_entry.add_props[prop_name] = prop_spec + optim_entries[i].last_entry = last_entry + + prompts = [ + optim_entry.to_prompt( + is_generation=True, include_oracle_score=prev_train_iter != 0, + config=config, max_score=max_score + ) + for optim_entry in optim_entries + ] + output_texts = [] + for i in range(0, len(prompts), config["generation_batch_size"]): + prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] + data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) + if type(model) == OPTForCausalLM: + del 
data["token_type_ids"] + for key, value in data.items(): + data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] + output = model.generate( + **data, + **config["generation_config"] + ) + gc.collect() + torch.cuda.empty_cache() + output_texts.extend(tokenizer.batch_decode(output)) + + current_unique_optim_entries = {} + with multiprocessing.Pool(processes=config["num_processes"]) as pol: + for i, molecule in enumerate(pol.map(create_molecule_entry, output_texts)): + if molecule and not optim_entries[i].contains_entry(molecule): + if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries: + molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries + for prop_name, prop_spec in additional_properties.items(): + molecule.add_props[prop_name] = prop_spec + molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule) + optim_entries[i].last_entry = molecule + current_unique_optim_entries[molecule.smiles] = optim_entries[i] + + num_of_molecules_to_score = min(len(current_unique_optim_entries), config["num_gens_per_iter"] - len(iter_unique_optim_entries)) + current_unique_smiles_list = list(current_unique_optim_entries.keys())[:num_of_molecules_to_score] + current_unique_optim_entries = {smiles: current_unique_optim_entries[smiles] for smiles in current_unique_smiles_list} + + if getattr(oracle, "takes_entry", False): + oracle_scores = oracle([current_unique_optim_entries[smiles].last_entry for smiles in current_unique_smiles_list]) + else: + oracle_scores = oracle(current_unique_smiles_list) + + for smiles, oracle_score in zip(current_unique_smiles_list, oracle_scores): + current_unique_optim_entries[smiles].last_entry.score = oracle_score + iter_unique_optim_entries[smiles] = current_unique_optim_entries[smiles] + file.write(f"generated smiles: {smiles}, score: {current_unique_optim_entries[smiles].last_entry.score:.4f}\n") + if current_unique_optim_entries[smiles].last_entry.score > max_score: + max_score = current_unique_optim_entries[smiles].last_entry.score + new_best_molecule_generated = True + + print(f"Iter unique optim entries: {len(iter_unique_optim_entries)}, budget: {len(oracle)}") if oracle.finish: break - initial_num_iter = num_iter - num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] - if num_iter > initial_num_iter: - tol_level += 1 - print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") - - # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - pool.add(iter_optim_entries) - file.write("Pool\n") - for i, optim_entry in enumerate(pool.optim_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - if "rej-sample-v2" in config["strategy"]: - # round_entries.extend(current_entries) - # round_entries = list(np.unique(round_entries))[::-1] - # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) - # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): - train_entries, validation_entries = pool.get_train_valid_entries() - print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") - file.write("Training entries\n") - for i, optim_entry in enumerate(train_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: 
{optim_entry.last_entry.score:.4f}\n") - file.write("Validation entries\n") - for i, optim_entry in enumerate(validation_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - train_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) - for optim_entry in train_entries - ] - }) - validation_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) - for optim_entry in validation_entries - ] - }) - train_dataset.shuffle(seed=42) - validation_dataset.shuffle(seed=42) - config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune( - model, tokenizer, - train_dataset, validation_dataset, - config["rej_sample_config"] - ) - gc.collect() - torch.cuda.empty_cache() - prev_train_iter = num_iter + + if oracle.finish: + break + initial_num_iter = num_iter + num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] + if num_iter > initial_num_iter: + tol_level += 1 + + if new_best_molecule_generated: + tol_level = 0 + + print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") + if num_iter != initial_num_iter: + config["generation_config"]["temperature"] += config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"]) * (config["generation_temperature"][1] - config["generation_temperature"][0]) + print(f"Generation temperature: {config['generation_config']['temperature']}") + + # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) + pool.add(list(iter_unique_optim_entries.values())) + file.write("Pool\n") + for i, optim_entry in enumerate(pool.optim_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + + if "rej-sample-v2" in config["strategy"]: + # round_entries.extend(current_entries) + # round_entries = list(np.unique(round_entries))[::-1] + # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) + # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: + if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): + train_entries, validation_entries = pool.get_train_valid_entries() + print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") + file.write("Training entries\n") + for i, optim_entry in enumerate(train_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + file.write("Validation entries\n") + for i, optim_entry in enumerate(validation_entries): + file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") + + train_dataset = Dataset.from_dict({ + "sample": [ + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + for optim_entry in train_entries + ] + }) + validation_dataset = Dataset.from_dict({ + "sample": [ + optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + for optim_entry in validation_entries + ] + }) + train_dataset.shuffle(seed=42) + validation_dataset.shuffle(seed=42) + config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] + supervised_fine_tune( + model, tokenizer, + train_dataset, validation_dataset, + config["rej_sample_config"] + ) + gc.collect() + torch.cuda.empty_cache() + 
prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 3908902..4718a6d 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -49,7 +49,8 @@ def generate_random_number(lower, upper): def canonicalize(smiles): - return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True) + mol = Chem.MolFromSmiles(smiles) + return Chem.MolToSmiles(mol, canonical=True) # return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), kekuleSmiles=True) @@ -194,11 +195,12 @@ def __init__(self, last_entry, mol_entries): def to_prompt( self, is_generation: bool, include_oracle_score: bool, config, + max_score=None ): prompt = "" - prompt = config["eos_token"] + # prompt = config["eos_token"] for mol_entry in self.mol_entries: - # prompt += config["eos_token"] + prompt += config["eos_token"] prompt += create_prompt_with_similars(mol_entry=mol_entry) for prop_name, prop_spec in mol_entry.add_props.items(): @@ -214,7 +216,7 @@ def to_prompt( prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" assert self.last_entry - # prompt += config["eos_token"] + prompt += config["eos_token"] if is_generation: prompt_with_similars = create_prompt_with_similars( self.last_entry, sim_range=config["sim_range"] @@ -231,15 +233,16 @@ def to_prompt( pass elif "rej-sample-v2" in config["strategy"]: if is_generation: - oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] - q_0_9 = ( - np.quantile(oracle_scores_of_mols_in_prompt, 0.9) - if oracle_scores_of_mols_in_prompt - else 0 - ) - desired_oracle_score = generate_random_number( - q_0_9, config["max_possible_oracle_score"] - ) + # oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] + # q_0_9 = ( + # np.quantile(oracle_scores_of_mols_in_prompt, 0.9) + # if oracle_scores_of_mols_in_prompt + # else 0 + # ) + # desired_oracle_score = generate_random_number( + # q_0_9, config["max_possible_oracle_score"] + # ) + desired_oracle_score = max_score oracle_score = desired_oracle_score else: oracle_score = self.last_entry.score @@ -252,6 +255,7 @@ def to_prompt( prompt += "[START_SMILES]" else: prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" + prompt += config["eos_token"] return prompt From d8ac65ff693feaa498c2127f50b774d52345c9f5 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 29 May 2024 13:32:36 +0400 Subject: [PATCH 31/45] add hparam config --- .../mol_opt/chemlactica_125m_hparams.yaml | 40 ++++++++++ chemlactica/mol_opt/hparam_search.py | 55 +++++++++++++ chemlactica/mol_opt/hparams_tune.yaml | 18 +++++ chemlactica/mol_opt/metrics.py | 20 +++++ chemlactica/mol_opt/optimization.py | 2 +- chemlactica/mol_opt/slurm_hparam_search.py | 79 +++++++++++++++++++ chemlactica/mol_opt/utils.py | 75 +++++++++++++++++- 7 files changed, 287 insertions(+), 2 deletions(-) create mode 100644 chemlactica/mol_opt/chemlactica_125m_hparams.yaml create mode 100644 chemlactica/mol_opt/hparam_search.py create mode 100644 chemlactica/mol_opt/hparams_tune.yaml create mode 100644 chemlactica/mol_opt/slurm_hparam_search.py diff --git a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml new file mode 100644 index 0000000..190557f --- /dev/null +++ b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml @@ -0,0 +1,40 @@ +# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480 +# checkpoint_path: 
/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-12288 +checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000 +tokenizer_path: /auto/home/tigranfahradyan/RetMol/RetMol/chemlactica/ChemLacticaTokenizer66 +pool_size: 50 +validation_perc: 0.2 +num_mols: 0 +num_similars: 1 +num_gens_per_iter: 200 +device: cuda:0 +sim_range: [0.8, 0.9] +# qed_range: [0.5, 0.9] +num_processes: 8 +generation_batch_size: 200 +eos_token: "" +generation_temperature: [1.0, 1.5] + +generation_config: + repetition_penalty: 1.0 + max_new_tokens: 100 + do_sample: true + eos_token_id: 20 + +strategy: [default] + +rej_sample_config: + train_tol_level: 3 + checkpoints_dir: ./ + max_learning_rate: 0.00001 + train_batch_size: 2 + gradient_accumulation_steps: 8 + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.999 + warmup_steps: 0 + global_gradient_norm: 1.0 + dataloader_num_workers: 1 + max_seq_length: 2048 + num_train_epochs: 5 + packing: false \ No newline at end of file diff --git a/chemlactica/mol_opt/hparam_search.py b/chemlactica/mol_opt/hparam_search.py new file mode 100644 index 0000000..652bce5 --- /dev/null +++ b/chemlactica/mol_opt/hparam_search.py @@ -0,0 +1,55 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +import yaml +import datetime +import argparse +import os +from utils import ConstraedTPSAOracle +from typing import List +from chemlactica.mol_opt.optimization import optimize + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + + +def default_train_condition(num_iter, tol_level, prev_train_iter): + return num_iter - prev_train_iter >= 3 + + +def tolerance_train_condition(cur_tol_level, train_tol_level): + return cur_tol_level >= train_tol_level + + +def choose_train_condition(name): + return { + "default" : default_train_condition, + "tolerance": tolerance_train_condition + }[name] + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--run_name", type=str, required=False) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--config_default", type=str, required=False, default="chemlactica/chemlactica_125m_hparams.yaml") + parser.add_argument("--n_runs", type=int, required=False, default=1) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_arguments() + config = yaml.safe_load(open(args.config_default)) + print(config) + + model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"]) + tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left") + + seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] + oracle = ConstraedTPSAOracle(max_oracle_calls=5000) + for seed in seeds[:args.n_runs]: + config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log") + config["rej_sample_config"]["should_train"] = choose_train_condition("tolerance") + optimize( + model, tokenizer, + oracle, config + ) \ No newline at end of file diff --git a/chemlactica/mol_opt/hparams_tune.yaml b/chemlactica/mol_opt/hparams_tune.yaml new file mode 100644 index 0000000..71721f6 --- /dev/null +++ b/chemlactica/mol_opt/hparams_tune.yaml @@ -0,0 +1,18 @@ +name: chemlactica +method: grid +metric: + goal: maximize + name: avg_auc +parameters: + strategy: [[default]] + + pool_size: [10, 30, 50] + num_mols: [0, 1, 2, 3, 5] + num_similars: [0, 1, 2, 3, 5] + num_gens_per_iter: [200, 400, 600] + 
generation_temperature: [[1.0, 1.0], [1.5, 1.5], [1.0, 1.5]] + + # rej_sample_config: + # num_train_epochs: [1, 3, 5, 7, 9] + # train_tol_level: [1, 3, 5, 7, 9] + # max_learning_rate: [0.0001, 0.00001, 0.000001] \ No newline at end of file diff --git a/chemlactica/mol_opt/metrics.py b/chemlactica/mol_opt/metrics.py index 4354519..f063171 100644 --- a/chemlactica/mol_opt/metrics.py +++ b/chemlactica/mol_opt/metrics.py @@ -2,6 +2,26 @@ import torch +def top_auc(buffer, top_n, finish, freq_log, max_oracle_calls): + sum = 0 + prev = 0 + called = 0 + ordered_results = list(sorted(buffer.items(), key=lambda kv: kv[1][1], reverse=False)) + for idx in range(freq_log, min(len(buffer), max_oracle_calls), freq_log): + temp_result = ordered_results[:idx] + temp_result = list(sorted(temp_result, key=lambda kv: kv[1][0], reverse=True))[:top_n] + top_n_now = np.mean([item[1][0] for item in temp_result]) + sum += freq_log * (top_n_now + prev) / 2 + prev = top_n_now + called = idx + temp_result = list(sorted(ordered_results, key=lambda kv: kv[1][0], reverse=True))[:top_n] + top_n_now = np.mean([item[1][0] for item in temp_result]) + sum += (len(buffer) - called) * (top_n_now + prev) / 2 + if finish and len(buffer) < max_oracle_calls: + sum += (max_oracle_calls - len(buffer)) * top_n_now + return sum / max_oracle_calls + + def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1): diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index c4f2357..a8de8c6 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -175,7 +175,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter): + if config["rej_sample_config"]["should_train"](tol_level, config["rej_sample_config"]["train_tol_level"]): train_entries, validation_entries = pool.get_train_valid_entries() print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") file.write("Training entries\n") diff --git a/chemlactica/mol_opt/slurm_hparam_search.py b/chemlactica/mol_opt/slurm_hparam_search.py new file mode 100644 index 0000000..eaa775b --- /dev/null +++ b/chemlactica/mol_opt/slurm_hparam_search.py @@ -0,0 +1,79 @@ +import submitit +import subprocess +import itertools as it +import datetime +import yaml +import os +import copy + + +def create_hparam_configs(config_file_path): + config_tune = yaml.safe_load(open("hparams_tune.yaml")) + config_merged = {} + for key, value in config_tune["parameters"].items(): + if type(value) == list: + config_merged[key] = value + else: + for k, v in value.items(): + config_merged[key+'+'+k] = v + + config_default = yaml.safe_load(open(config_file_path)) + hparam_names = list(config_merged.keys()) + all_configs = [] + for params in it.product(*config_merged.values()): + # pprint(params) + # pprint(hparam_names) + config = copy.deepcopy(config_default) + for i, p in enumerate(params): + if '+' in hparam_names[i]: + a, b = hparam_names[i].split("+") + config[a][b] = p + else: + config[hparam_names[i]] = p + # pprint(params) + # pprint(config) + all_configs.append(config) + # print(config) + return all_configs + + +if __name__ == "__main__": + n_runs = 3 + + config_file_path = "chemlactica_125m_hparams.yaml" + # config_file_path = 
"main/chemlactica/chemma_2b_hparams.yaml" + hparam_configs = create_hparam_configs(config_file_path) + # infer_config = [yaml.safe_load(open(config_file_path))] + model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2]) + + executor = submitit.AutoExecutor(folder="/auto/home/tigranfahradyan/slurm_jobs/PMO/job_%j") + executor.update_parameters( + name="chemlactica-pmo", timeout_min=n_runs * 3 * 60, + gpus_per_node=1, nodes=1, mem_gb=50, cpus_per_task=8, + slurm_array_parallelism=10 + ) + jobs = [] + with executor.batch(): + for config in hparam_configs[:1]: + formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") + base = f"results/{formatted_date_time}" + os.makedirs(base, exist_ok=True) + v = 0 + name = model_name + "-" + "+".join(config["strategy"]) + while os.path.exists(os.path.join(base, f"{name}-{v}")): + v += 1 + output_dir = os.path.join(base, f"{name}-{v}") + output_dir += "tune" + # output_dir = "main/chemlactica/results/2024-05-11/chemlactica-125m-rej-sample-4" + os.makedirs(output_dir, exist_ok=True) + yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w")) + function = submitit.helpers.CommandFunction([ + 'python3', 'hparam_search.py', + '--config_default', os.path.join(output_dir, "hparams.yaml"), + '--output_dir', output_dir, + '--n_runs', str(n_runs), + ]) + print(' '.join(function.command)) + # subprocess.run(function.command) + job = executor.submit(function) + jobs.append(job) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 4718a6d..cdcc627 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -5,8 +5,9 @@ from pathlib import Path import numpy as np import torch +from metrics import top_auc from rdkit import Chem, DataStructs, RDLogger -from rdkit.Chem import AllChem, MACCSkeys +from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors # Disable RDKit logs RDLogger.DisableLog("rdApp.*") @@ -85,6 +86,78 @@ def __hash__(self): return hash(self.smiles) +class ConstraedTPSAOracle: + def __init__(self, max_oracle_calls: int): + self.max_oracle_calls = max_oracle_calls + self.freq_log = 100 + self.mol_buffer = {} + self.max_possible_oracle_score = 1.0 + self.takes_entry = True + + def __call__(self, molecules): + oracle_scores = [] + for molecule in molecules: + if self.mol_buffer.get(molecule.smiles): + oracle_scores.append(sum(self.mol_buffer[molecule.smiles][0])) + else: + try: + tpsa = rdMolDescriptors.CalcTPSA(molecule.mol) + tpsa_score = min(tpsa / 1000, 1) + weight = rdMolDescriptors.CalcExactMolWt(molecule.mol) + if weight <= 349: + weight_score = 1 + elif weight >= 500: + weight_score = 0 + else: + weight_score = -0.00662 * weight + 3.31125 + + num_rings = rdMolDescriptors.CalcNumRings(molecule.mol) + if num_rings >= 2: + num_rights_score = 1 + else: + num_rights_score = 0 + # print(tpsa_score, weight_score, num_rights_score) + oracle_score = (tpsa_score + weight_score + num_rights_score) / 3 + except Exception as e: + print(e) + oracle_score = 0 + self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1] + if len(self.mol_buffer) % 100 == 0: + self.log_intermediate() + oracle_scores.append(oracle_score) + return oracle_scores + + def log_intermediate(self): + scores = [v[0] for v in self.mol_buffer.values()] + scores_sorted = sorted(scores, reverse=True)[:100] + n_calls = len(self.mol_buffer) + + score_avg_top1 = np.max(scores_sorted) + score_avg_top10 = np.mean(scores_sorted[:10]) + score_avg_top100 = np.mean(scores_sorted) + + 
print(f"{n_calls}/{self.max_oracle_calls} | ", + f"auc_top1: {top_auc(self.mol_buffer, 1, False, self.freq_log, self.max_oracle_calls)} | ", + f"auc_top10: {top_auc(self.mol_buffer, 10, False, self.freq_log, self.max_oracle_calls)} | ", + f"auc_top100: {top_auc(self.mol_buffer, 100, False, self.freq_log, self.max_oracle_calls)}") + + print(f'avg_top1: {score_avg_top1:.3f} | ' + f'avg_top10: {score_avg_top10:.3f} | ' + f'avg_top100: {score_avg_top100:.3f}') + + def __len__(self): + return len(self.mol_buffer) + + @property + def budget(self): + return self.max_oracle_calls + + @property + def finish(self): + return len(self.mol_buffer) >= self.max_oracle_calls + + + class Pool: def __init__(self, size, validation_perc: float): self.size = size From 05af82d178d6acb707e2dfa48c15dc3d51d8e4a2 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 29 May 2024 11:11:47 +0000 Subject: [PATCH 32/45] add no slurm parallel runs to mol_opt --- .gitignore | 2 + .../mol_opt/chemlactica_125m_hparams.yaml | 4 +- chemlactica/mol_opt/hparam_search.py | 2 +- chemlactica/mol_opt/no_slurm_hparam_search.py | 119 ++++++++++++++++++ 4 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 chemlactica/mol_opt/no_slurm_hparam_search.py diff --git a/.gitignore b/.gitignore index 5919610..2b2f680 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +chemlactica/mol_opt/results/ + # Aim experiment metadata aim/ diff --git a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml index 190557f..f9b8ad3 100644 --- a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml +++ b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml @@ -1,7 +1,7 @@ # checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480 # checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-12288 -checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000 -tokenizer_path: /auto/home/tigranfahradyan/RetMol/RetMol/chemlactica/ChemLacticaTokenizer66 +checkpoint_path: /home/admin/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000 +tokenizer_path: /home/admin/tigran/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66 pool_size: 50 validation_perc: 0.2 num_mols: 0 diff --git a/chemlactica/mol_opt/hparam_search.py b/chemlactica/mol_opt/hparam_search.py index 652bce5..4730f78 100644 --- a/chemlactica/mol_opt/hparam_search.py +++ b/chemlactica/mol_opt/hparam_search.py @@ -45,7 +45,7 @@ def parse_arguments(): tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left") seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] - oracle = ConstraedTPSAOracle(max_oracle_calls=5000) + oracle = ConstraedTPSAOracle(max_oracle_calls=15000) for seed in seeds[:args.n_runs]: config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log") config["rej_sample_config"]["should_train"] = choose_train_condition("tolerance") diff --git a/chemlactica/mol_opt/no_slurm_hparam_search.py b/chemlactica/mol_opt/no_slurm_hparam_search.py new file mode 100644 index 0000000..219c0c9 --- /dev/null +++ b/chemlactica/mol_opt/no_slurm_hparam_search.py @@ -0,0 +1,119 @@ +import submitit +import subprocess +import itertools as it +import datetime +import yaml +import os +import copy +import time +import torch + + +def is_gpu_being_used(gpu_id): + try: + # Run the nvidia-smi command + cmd = 
['nvidia-smi','-i',f"{gpu_id}"] + output = subprocess.check_output(cmd) + output = output.decode('utf-8') + if "No running processes found" in output: + return False + else: + return True + + except subprocess.CalledProcessError as e: + print(f"Error executing nvidia-smi command: {e}") + + +def create_hparam_configs(config_file_path): + config_tune = yaml.safe_load(open("hparams_tune.yaml")) + config_merged = {} + for key, value in config_tune["parameters"].items(): + if type(value) == list: + config_merged[key] = value + else: + for k, v in value.items(): + config_merged[key+'+'+k] = v + + config_default = yaml.safe_load(open(config_file_path)) + hparam_names = list(config_merged.keys()) + all_configs = [] + for params in it.product(*config_merged.values()): + # pprint(params) + # pprint(hparam_names) + config = copy.deepcopy(config_default) + for i, p in enumerate(params): + if '+' in hparam_names[i]: + a, b = hparam_names[i].split("+") + config[a][b] = p + else: + config[hparam_names[i]] = p + # pprint(params) + # pprint(config) + all_configs.append(config) + # print(config) + return all_configs + + +if __name__ == "__main__": + n_runs = 3 + + config_file_path = "chemlactica_125m_hparams.yaml" + # config_file_path = "main/chemlactica/chemma_2b_hparams.yaml" + hparam_configs = create_hparam_configs(config_file_path) + # infer_config = [yaml.safe_load(open(config_file_path))] + model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2]) + gpu_indices = [0, 1, 2, 3, 4, 5, 6, 7] + + index = 0 + while index < len(hparam_configs): + free_gpu_index = None + for gpu_index in gpu_indices: + gpu_is_free = True + print(f"Checking gpu: {gpu_index}") + for _ in range(10): + if is_gpu_being_used(gpu_index): + gpu_is_free = False + break + time.sleep(1) + if gpu_is_free: + free_gpu_index = gpu_index + print(f"gpu: {gpu_index} is free") + break + else: + print(f"gpu: {gpu_index} is being used") + if free_gpu_index is not None: + print(f"found a free gpu {free_gpu_index}, putting a job") + executor = submitit.LocalExecutor(folder="/home/admin/tigran/slurm_jobs/PMO/job_%j") + executor.update_parameters( + name="chemlactica-pmo", timeout_min=n_runs * 12 * 60, + visible_gpus=[free_gpu_index], + gpus_per_node=1, nodes=1, mem_gb=80, cpus_per_task=8, + slurm_array_parallelism=10 + ) + jobs = [] + with executor.batch(): + current_hparams = [hparam_configs[index]] + for config in current_hparams: + formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") + base = f"results/{formatted_date_time}" + os.makedirs(base, exist_ok=True) + v = 0 + name = model_name + "-" + "+".join(config["strategy"]) + while os.path.exists(os.path.join(base, f"{name}-{v}-hparam-search")): + v += 1 + output_dir = os.path.join(base, f"{name}-{v}-hparam-search") + os.makedirs(output_dir, exist_ok=True) + yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w")) + function = submitit.helpers.CommandFunction([ + 'python3', 'hparam_search.py', + '--config_default', os.path.join(output_dir, "hparams.yaml"), + '--output_dir', output_dir, + '--n_runs', str(n_runs), + ]) + print(' '.join(function.command)) + job = executor.submit(function) + jobs.append(job) + for job in jobs: + print(job.job_id) + index += 1 + free_gpu_index = None \ No newline at end of file From 16a00ca566183fece429f2c32263c976e7d69ad1 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 29 May 2024 22:21:29 +0400 Subject: [PATCH 33/45] keep optim state during the fine-tuning of mol optim process --- chemlactica/mol_opt/optimization.py | 
31 +++++++++++++++++++++----- chemlactica/mol_opt/tunning.py | 34 +++++++---------------------- chemlactica/mol_opt/utils.py | 2 +- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index a8de8c6..35992d4 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -7,10 +7,11 @@ import tqdm import random from functools import partial +from trl import SFTTrainer import numpy as np from transformers import OPTForCausalLM from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func -from chemlactica.mol_opt.tunning import supervised_fine_tune +from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback def create_similar_mol_entries(pool, mol_entry, num_similars): @@ -70,6 +71,10 @@ def optimize( pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) config["generation_config"]["temperature"] = config["generation_temperature"][0] + + if "rej-sample-v2" in config["strategy"]: + training_args = get_training_arguments(config["rej_sample_config"]) + optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], config["pool_size"]) max_score = 0 tol_level = 0 num_iter = 0 @@ -199,12 +204,26 @@ def optimize( }) train_dataset.shuffle(seed=42) validation_dataset.shuffle(seed=42) - config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"] - supervised_fine_tune( - model, tokenizer, - train_dataset, validation_dataset, - config["rej_sample_config"] + + model.train() + early_stopping_callback = CustomEarlyStopCallback( + early_stopping_patience=1, + early_stopping_threshold=0.0001 + ) + trainer = SFTTrainer( + model=model, + train_dataset=train_dataset, + eval_dataset=validation_dataset, + formatting_func=lambda x: x["sample"], + args=training_args, + packing=config["rej_sample_config"]["packing"], + tokenizer=tokenizer, + max_seq_length=config["rej_sample_config"]["max_seq_length"], + # data_collator=collator, + optimizers=[optimizer, lr_scheduler], + callbacks=[early_stopping_callback], ) + trainer.train() gc.collect() torch.cuda.empty_cache() prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index da587ac..f5bdfcc 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -57,12 +57,8 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra # return super().log(logs) -def supervised_fine_tune( - model, tokenizer, - train_dataset, validation_dataset, config - ): - model.train() - training_args = TrainingArguments( +def get_training_arguments(config): + return TrainingArguments( output_dir=config["checkpoints_dir"], per_device_train_batch_size=config["train_batch_size"], per_device_eval_batch_size=config["train_batch_size"], @@ -76,6 +72,9 @@ def supervised_fine_tune( logging_steps=1, metric_for_best_model="loss" ) + + +def get_optimizer_and_lr_scheduler(model, config, tr_ds_size): optimizer = torch.optim.AdamW( model.parameters(), lr=config["max_learning_rate"], @@ -85,25 +84,8 @@ def supervised_fine_tune( lr_scheduler = get_polynomial_decay_schedule_with_warmup( optimizer, num_warmup_steps=config["warmup_steps"], - num_training_steps=config["num_train_epochs"] * (len(train_dataset) // config["train_batch_size"] + 1), - lr_end=0.999 * config["max_learning_rate"], + 
num_training_steps=config["num_train_epochs"] * (tr_ds_size // config["train_batch_size"] + 1), + lr_end=0.99999 * config["max_learning_rate"], power=1.0, ) - early_stopping_callback = CustomEarlyStopCallback( - early_stopping_patience=1, - early_stopping_threshold=0.001 - ) - trainer = SFTTrainer( - model=model, - train_dataset=train_dataset, - eval_dataset=validation_dataset, - formatting_func=config["formatting_func"], - args=training_args, - packing=config["packing"], - tokenizer=tokenizer, - max_seq_length=config["max_seq_length"], - # data_collator=collator, - optimizers=[optimizer, lr_scheduler], - callbacks=[early_stopping_callback], - ) - trainer.train() + return optimizer, lr_scheduler diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index cdcc627..e97d043 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -5,7 +5,7 @@ from pathlib import Path import numpy as np import torch -from metrics import top_auc +from chemlactica.mol_opt.metrics import top_auc from rdkit import Chem, DataStructs, RDLogger from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors From 832906cded03e57bc8b0dffb5217e99347d72789 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 3 Jun 2024 23:29:26 +0400 Subject: [PATCH 34/45] refine --- chemlactica/mol_opt/optimization.py | 43 ++++++++++++++++------------- chemlactica/mol_opt/tunning.py | 16 +++++++---- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 35992d4..28c98c3 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -6,6 +6,7 @@ import math import tqdm import random +import shutil from functools import partial from trl import SFTTrainer import numpy as np @@ -74,7 +75,11 @@ def optimize( if "rej-sample-v2" in config["strategy"]: training_args = get_training_arguments(config["rej_sample_config"]) - optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], config["pool_size"]) + effective_batch_size = config["rej_sample_config"]["gradient_accumulation_steps"] * config["rej_sample_config"]["train_batch_size"] + num_single_train_steps = config["rej_sample_config"]["num_train_epochs"] * ((1 - config["validation_perc"]) * config["pool_size"] / effective_batch_size) + max_num_trains = oracle.max_oracle_calls / (config["rej_sample_config"]["train_tol_level"] * config["num_gens_per_iter"]) + max_num_train_steps = int(max_num_trains * num_single_train_steps) + optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], max_num_train_steps) max_score = 0 tol_level = 0 num_iter = 0 @@ -85,7 +90,7 @@ def optimize( iter_unique_optim_entries: List[OptimEntry] = {} while len(iter_unique_optim_entries) < config["num_gens_per_iter"]: optim_entries = create_optimization_entries( - config["num_gens_per_iter"], pool, + config["generation_batch_size"], pool, config=config ) for i in range(len(optim_entries)): @@ -105,20 +110,18 @@ def optimize( for optim_entry in optim_entries ] output_texts = [] - for i in range(0, len(prompts), config["generation_batch_size"]): - prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] - data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device) - if type(model) == OPTForCausalLM: - del data["token_type_ids"] - for key, value in data.items(): - data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] - output = 
model.generate( - **data, - **config["generation_config"] - ) - gc.collect() - torch.cuda.empty_cache() - output_texts.extend(tokenizer.batch_decode(output)) + data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + if type(model) == OPTForCausalLM: + del data["token_type_ids"] + for key, value in data.items(): + data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] + output = model.generate( + **data, + **config["generation_config"] + ) + gc.collect() + torch.cuda.empty_cache() + output_texts.extend(tokenizer.batch_decode(output)) current_unique_optim_entries = {} with multiprocessing.Pool(processes=config["num_processes"]) as pol: @@ -165,7 +168,7 @@ def optimize( tol_level = 0 print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") - if num_iter != initial_num_iter: + if num_iter > initial_num_iter: config["generation_config"]["temperature"] += config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"]) * (config["generation_temperature"][1] - config["generation_temperature"][0]) print(f"Generation temperature: {config['generation_config']['temperature']}") @@ -180,7 +183,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if config["rej_sample_config"]["should_train"](tol_level, config["rej_sample_config"]["train_tol_level"]): + if tol_level >= config["rej_sample_config"]["train_tol_level"]: train_entries, validation_entries = pool.get_train_valid_entries() print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.") file.write("Training entries\n") @@ -220,10 +223,12 @@ def optimize( tokenizer=tokenizer, max_seq_length=config["rej_sample_config"]["max_seq_length"], # data_collator=collator, - optimizers=[optimizer, lr_scheduler], callbacks=[early_stopping_callback], + optimizers=[optimizer, lr_scheduler], ) trainer.train() + shutil.rmtree(training_args.output_dir) gc.collect() torch.cuda.empty_cache() + tol_level = 0 prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index f5bdfcc..456b728 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -4,6 +4,8 @@ from torch.optim.lr_scheduler import ConstantLR import torch import math +import time +from chemlactica.mol_opt.utils import generate_random_number class CustomEarlyStopCallback(TrainerCallback): @@ -58,23 +60,27 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra def get_training_arguments(config): + checkpoints_dir = config["checkpoints_dir"] + "_" + str(time.time()) return TrainingArguments( - output_dir=config["checkpoints_dir"], + output_dir=checkpoints_dir, per_device_train_batch_size=config["train_batch_size"], per_device_eval_batch_size=config["train_batch_size"], max_grad_norm=config["global_gradient_norm"], num_train_epochs=config["num_train_epochs"], evaluation_strategy="epoch", + # save_strategy="epoch", dataloader_drop_last=False, dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], gradient_accumulation_steps=config["gradient_accumulation_steps"], logging_steps=1, - metric_for_best_model="loss" + metric_for_best_model="loss", + # load_best_model_at_end=True, + # save_total_limit=1 ) -def get_optimizer_and_lr_scheduler(model, config, 
tr_ds_size): +def get_optimizer_and_lr_scheduler(model, config, max_train_steps): optimizer = torch.optim.AdamW( model.parameters(), lr=config["max_learning_rate"], @@ -84,8 +90,8 @@ def get_optimizer_and_lr_scheduler(model, config, tr_ds_size): lr_scheduler = get_polynomial_decay_schedule_with_warmup( optimizer, num_warmup_steps=config["warmup_steps"], - num_training_steps=config["num_train_epochs"] * (tr_ds_size // config["train_batch_size"] + 1), - lr_end=0.99999 * config["max_learning_rate"], + num_training_steps=max_train_steps, + lr_end=0, power=1.0, ) return optimizer, lr_scheduler From 1f8142fd340d7c8461c873de2feb1ccc8efbdf12 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Thu, 6 Jun 2024 12:10:32 +0400 Subject: [PATCH 35/45] add lr annealing/not annealing --- chemlactica/mol_opt/tunning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 456b728..eb41f89 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -92,6 +92,7 @@ def get_optimizer_and_lr_scheduler(model, config, max_train_steps): num_warmup_steps=config["warmup_steps"], num_training_steps=max_train_steps, lr_end=0, + # lr_end=0.9999 * config["max_learning_rate"], power=1.0, ) return optimizer, lr_scheduler From 91d9390536b492c7da024fc94e3cb22cbce71bca Mon Sep 17 00:00:00 2001 From: tigranfah Date: Thu, 6 Jun 2024 19:16:52 +0400 Subject: [PATCH 36/45] add ending lr in config --- chemlactica/mol_opt/tunning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index eb41f89..fc7d0ed 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -91,8 +91,7 @@ def get_optimizer_and_lr_scheduler(model, config, max_train_steps): optimizer, num_warmup_steps=config["warmup_steps"], num_training_steps=max_train_steps, - lr_end=0, - # lr_end=0.9999 * config["max_learning_rate"], + lr_end=config["lr_end"], power=1.0, ) return optimizer, lr_scheduler From 2cf0f8bf942bdf371ad6c1928ede4e4783ab5dd1 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 7 Jun 2024 01:55:15 +0400 Subject: [PATCH 37/45] remove multiprocessing --- chemlactica/mol_opt/optimization.py | 20 ++++++++++---------- chemlactica/mol_opt/utils.py | 3 ++- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 28c98c3..7f8bd18 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -124,16 +124,16 @@ def optimize( output_texts.extend(tokenizer.batch_decode(output)) current_unique_optim_entries = {} - with multiprocessing.Pool(processes=config["num_processes"]) as pol: - for i, molecule in enumerate(pol.map(create_molecule_entry, output_texts)): - if molecule and not optim_entries[i].contains_entry(molecule): - if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries: - molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries - for prop_name, prop_spec in additional_properties.items(): - molecule.add_props[prop_name] = prop_spec - molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule) - optim_entries[i].last_entry = molecule - current_unique_optim_entries[molecule.smiles] = optim_entries[i] + # with multiprocessing.Pool(processes=config["num_processes"]) as pol: + for i, molecule in enumerate(map(create_molecule_entry, output_texts)): + if 
molecule and not optim_entries[i].contains_entry(molecule): + if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries: + molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries + for prop_name, prop_spec in additional_properties.items(): + molecule.add_props[prop_name] = prop_spec + molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule) + optim_entries[i].last_entry = molecule + current_unique_optim_entries[molecule.smiles] = optim_entries[i] num_of_molecules_to_score = min(len(current_unique_optim_entries), config["num_gens_per_iter"] - len(iter_unique_optim_entries)) current_unique_smiles_list = list(current_unique_optim_entries.keys())[:num_of_molecules_to_score] diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index e97d043..d0affd9 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -223,7 +223,8 @@ def get_train_valid_entries(self): return train_entries, valid_entries def random_subset(self, subset_size): - rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) + rand_inds = np.random.permutation(len(self.optim_entries)) + rand_inds = rand_inds[:subset_size] return [self.optim_entries[i] for i in rand_inds] def __len__(self): From 9c9691955d56daaf75635ee1e2ce139aef948678 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Fri, 7 Jun 2024 19:49:57 +0400 Subject: [PATCH 38/45] change sampling from pool --- chemlactica/mol_opt/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index d0affd9..0c9d9cc 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -223,8 +223,7 @@ def get_train_valid_entries(self): return train_entries, valid_entries def random_subset(self, subset_size): - rand_inds = np.random.permutation(len(self.optim_entries)) - rand_inds = rand_inds[:subset_size] + rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2)) return [self.optim_entries[i] for i in rand_inds] def __len__(self): From 5d7a25844065784109902aad7c3cd296333b0732 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 8 Jun 2024 01:13:13 +0400 Subject: [PATCH 39/45] fix --- chemlactica/mol_opt/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 7f8bd18..d2bf603 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -16,7 +16,7 @@ def create_similar_mol_entries(pool, mol_entry, num_similars): - similar_entries = [e.last_entry for e in pool.random_subset(num_similars + 1)] + similar_entries = [e.last_entry for e in pool.random_subset(num_similars)] count = 0 valid_similar_entries = [] for similar_entry in similar_entries: From be4096d391b8d6798b7d72f7bfe1294328f28306 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 8 Jun 2024 23:27:33 +0400 Subject: [PATCH 40/45] add model selection based on validation loss --- chemlactica/mol_opt/optimization.py | 7 +--- chemlactica/mol_opt/tunning.py | 51 +++++++++++++++-------------- chemlactica/mol_opt/utils.py | 4 ++- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index d2bf603..7457614 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -12,7 +12,7 @@ import numpy as np from 
transformers import OPTForCausalLM from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func -from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback +from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler def create_similar_mol_entries(pool, mol_entry, num_similars): @@ -209,10 +209,6 @@ def optimize( validation_dataset.shuffle(seed=42) model.train() - early_stopping_callback = CustomEarlyStopCallback( - early_stopping_patience=1, - early_stopping_threshold=0.0001 - ) trainer = SFTTrainer( model=model, train_dataset=train_dataset, @@ -223,7 +219,6 @@ def optimize( tokenizer=tokenizer, max_seq_length=config["rej_sample_config"]["max_seq_length"], # data_collator=collator, - callbacks=[early_stopping_callback], optimizers=[optimizer, lr_scheduler], ) trainer.train() diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index fc7d0ed..eeac662 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -8,30 +8,30 @@ from chemlactica.mol_opt.utils import generate_random_number -class CustomEarlyStopCallback(TrainerCallback): +# class CustomEarlyStopCallback(TrainerCallback): - def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None: - super().__init__() - self.best_valid_loss = math.inf - self.early_stopping_patience = early_stopping_patience - self.current_patiance = 0 - self.early_stopping_threshold = early_stopping_threshold +# def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None: +# super().__init__() +# self.best_valid_loss = math.inf +# self.early_stopping_patience = early_stopping_patience +# self.current_patiance = 0 +# self.early_stopping_threshold = early_stopping_threshold - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - self.best_valid_loss = math.inf - self.current_patiance = 0 - return super().on_train_begin(args, state, control, **kwargs) +# def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): +# self.best_valid_loss = math.inf +# self.current_patiance = 0 +# return super().on_train_begin(args, state, control, **kwargs) - def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): - if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold: - self.current_patiance += 1 - else: - self.current_patiance = 0 - self.best_valid_loss = metrics["eval_loss"] - print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}") - if self.current_patiance >= self.early_stopping_patience: - control.should_training_stop = True - return super().on_evaluate(args, state, control, **kwargs) +# def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): +# if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold: +# self.current_patiance += 1 +# else: +# self.current_patiance = 0 +# self.best_valid_loss = metrics["eval_loss"] +# print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}") +# if self.current_patiance >= self.early_stopping_patience: +# control.should_training_stop = True +# return super().on_evaluate(args, state, control, **kwargs) # class CustomSFTTrainer(SFTTrainer): @@ -60,7 +60,7 @@ def 
on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra def get_training_arguments(config): - checkpoints_dir = config["checkpoints_dir"] + "_" + str(time.time()) + checkpoints_dir = f"{config['checkpoints_dir']}/checkpoint-{time.time():.4f}" return TrainingArguments( output_dir=checkpoints_dir, per_device_train_batch_size=config["train_batch_size"], @@ -68,15 +68,16 @@ def get_training_arguments(config): max_grad_norm=config["global_gradient_norm"], num_train_epochs=config["num_train_epochs"], evaluation_strategy="epoch", - # save_strategy="epoch", + save_strategy="epoch", dataloader_drop_last=False, dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], gradient_accumulation_steps=config["gradient_accumulation_steps"], logging_steps=1, + save_safetensors=False, metric_for_best_model="loss", - # load_best_model_at_end=True, - # save_total_limit=1 + load_best_model_at_end=True, + save_total_limit=1 ) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 0c9d9cc..16b36e3 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -223,7 +223,9 @@ def get_train_valid_entries(self): return train_entries, valid_entries def random_subset(self, subset_size): - rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2)) + # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2)) + rand_inds = np.random.permutation(len(self.optim_entries)) + rand_inds = rand_inds[:subset_size] return [self.optim_entries[i] for i in rand_inds] def __len__(self): From a84c79cfd64e9fb9470230497b4467414e205ebc Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sun, 9 Jun 2024 22:41:33 +0400 Subject: [PATCH 41/45] remove model selection with validation loss, because of slowness --- chemlactica/mol_opt/optimization.py | 16 ++++----- chemlactica/mol_opt/tunning.py | 51 +++++++++++++++-------------- chemlactica/mol_opt/utils.py | 2 +- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 7457614..781c00d 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -1,18 +1,12 @@ from typing import List import torch from datasets import Dataset -import multiprocessing import gc -import math -import tqdm -import random import shutil -from functools import partial from trl import SFTTrainer -import numpy as np from transformers import OPTForCausalLM -from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func -from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler +from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool +from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback def create_similar_mol_entries(pool, mol_entry, num_similars): @@ -208,6 +202,11 @@ def optimize( train_dataset.shuffle(seed=42) validation_dataset.shuffle(seed=42) + early_stopping_callback = CustomEarlyStopCallback( + early_stopping_patience=1, + early_stopping_threshold=0.0001 + ) + model.train() trainer = SFTTrainer( model=model, @@ -219,6 +218,7 @@ def optimize( tokenizer=tokenizer, max_seq_length=config["rej_sample_config"]["max_seq_length"], # data_collator=collator, + callbacks=[early_stopping_callback], optimizers=[optimizer, lr_scheduler], ) trainer.train() diff --git a/chemlactica/mol_opt/tunning.py 
b/chemlactica/mol_opt/tunning.py index eeac662..cd680f0 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -8,30 +8,30 @@ from chemlactica.mol_opt.utils import generate_random_number -# class CustomEarlyStopCallback(TrainerCallback): +class CustomEarlyStopCallback(TrainerCallback): -# def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None: -# super().__init__() -# self.best_valid_loss = math.inf -# self.early_stopping_patience = early_stopping_patience -# self.current_patiance = 0 -# self.early_stopping_threshold = early_stopping_threshold + def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None: + super().__init__() + self.best_valid_loss = math.inf + self.early_stopping_patience = early_stopping_patience + self.current_patiance = 0 + self.early_stopping_threshold = early_stopping_threshold -# def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): -# self.best_valid_loss = math.inf -# self.current_patiance = 0 -# return super().on_train_begin(args, state, control, **kwargs) + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.best_valid_loss = math.inf + self.current_patiance = 0 + return super().on_train_begin(args, state, control, **kwargs) -# def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): -# if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold: -# self.current_patiance += 1 -# else: -# self.current_patiance = 0 -# self.best_valid_loss = metrics["eval_loss"] -# print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}") -# if self.current_patiance >= self.early_stopping_patience: -# control.should_training_stop = True -# return super().on_evaluate(args, state, control, **kwargs) + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): + if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold: + self.current_patiance += 1 + else: + self.current_patiance = 0 + self.best_valid_loss = metrics["eval_loss"] + print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}") + if self.current_patiance >= self.early_stopping_patience: + control.should_training_stop = True + return super().on_evaluate(args, state, control, **kwargs) # class CustomSFTTrainer(SFTTrainer): @@ -68,7 +68,8 @@ def get_training_arguments(config): max_grad_norm=config["global_gradient_norm"], num_train_epochs=config["num_train_epochs"], evaluation_strategy="epoch", - save_strategy="epoch", + # save_strategy="epoch", + save_strategy="no", dataloader_drop_last=False, dataloader_pin_memory=True, dataloader_num_workers=config["dataloader_num_workers"], @@ -76,8 +77,8 @@ def get_training_arguments(config): logging_steps=1, save_safetensors=False, metric_for_best_model="loss", - load_best_model_at_end=True, - save_total_limit=1 + # load_best_model_at_end=True, + # save_total_limit=1 ) @@ -95,4 +96,4 @@ def get_optimizer_and_lr_scheduler(model, config, max_train_steps): lr_end=config["lr_end"], power=1.0, ) - return optimizer, lr_scheduler + return optimizer, lr_scheduler \ No newline at end of file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 16b36e3..bc231f5 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -162,7 +162,7 @@ 
class Pool: def __init__(self, size, validation_perc: float): self.size = size self.optim_entries: List[OptimEntry] = [] - self.num_validation_entries = int(size * validation_perc) + self.num_validation_entries = int(size * validation_perc + 1) # def random_dump(self, num): # for _ in range(num): From c393b1aae07ce6eba79ede8a0443447af50c5cba Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 11 Jun 2024 15:54:57 +0400 Subject: [PATCH 42/45] take random permutation of top elements from the pool, when constructing a prompt --- chemlactica/mol_opt/tunning.py | 1 - chemlactica/mol_opt/utils.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index cd680f0..289eefe 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -33,7 +33,6 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra control.should_training_stop = True return super().on_evaluate(args, state, control, **kwargs) - # class CustomSFTTrainer(SFTTrainer): # # def __init__(self, *args, patience, toll, **kwargs): diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index bc231f5..b018a0e 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -223,9 +223,9 @@ def get_train_valid_entries(self): return train_entries, valid_entries def random_subset(self, subset_size): - # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2)) - rand_inds = np.random.permutation(len(self.optim_entries)) - rand_inds = rand_inds[:subset_size] + rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) + # rand_inds = np.random.permutation(len(self.optim_entries)) + # rand_inds = rand_inds[:subset_size] return [self.optim_entries[i] for i in rand_inds] def __len__(self): From 5537720e8e08eb751c261c2c366e49ae944a6eff Mon Sep 17 00:00:00 2001 From: tigranfah Date: Wed, 12 Jun 2024 01:12:40 +0400 Subject: [PATCH 43/45] add custom model selection with fine-tuning validation loss --- chemlactica/mol_opt/optimization.py | 15 +++++++++------ chemlactica/mol_opt/tunning.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 781c00d..b1f17ca 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -6,7 +6,7 @@ from trl import SFTTrainer from transformers import OPTForCausalLM from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool -from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback +from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback, CustomModelSelectionCallback def create_similar_mol_entries(pool, mol_entry, num_similars): @@ -202,10 +202,11 @@ def optimize( train_dataset.shuffle(seed=42) validation_dataset.shuffle(seed=42) - early_stopping_callback = CustomEarlyStopCallback( - early_stopping_patience=1, - early_stopping_threshold=0.0001 - ) + # early_stopping_callback = CustomEarlyStopCallback( + # early_stopping_patience=1, + # early_stopping_threshold=0.0001 + # ) + model_selection_callback = CustomModelSelectionCallback() model.train() trainer = SFTTrainer( @@ -218,10 +219,12 @@ def optimize( tokenizer=tokenizer, max_seq_length=config["rej_sample_config"]["max_seq_length"], # data_collator=collator, - 
callbacks=[early_stopping_callback], + callbacks=[model_selection_callback], optimizers=[optimizer, lr_scheduler], ) trainer.train() + print(f"Loading the best model state dict with validation loss {model_selection_callback.best_validation_loss}") + model.load_state_dict(model_selection_callback.best_model_state_dict) shutil.rmtree(training_args.output_dir) gc.collect() torch.cuda.empty_cache() diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py index 289eefe..1332f0e 100644 --- a/chemlactica/mol_opt/tunning.py +++ b/chemlactica/mol_opt/tunning.py @@ -5,6 +5,7 @@ import torch import math import time +from collections import OrderedDict from chemlactica.mol_opt.utils import generate_random_number @@ -33,6 +34,23 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra control.should_training_stop = True return super().on_evaluate(args, state, control, **kwargs) + +class CustomModelSelectionCallback(TrainerCallback): + + def __init__(self): + super().__init__() + self.best_validation_loss: float = math.inf + self.best_model_state_dict: OrderedDict = OrderedDict() + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, model, **kwargs): + if metrics["eval_loss"] <= self.best_validation_loss: + self.best_validation_loss = metrics["eval_loss"] + print(f"Better validation loss achieved {self.best_validation_loss}, updating the state dict.") + for key, value in model.state_dict().items(): + self.best_model_state_dict[key] = value.detach().clone() + return super().on_evaluate(args, state, control, **kwargs) + + # class CustomSFTTrainer(SFTTrainer): # # def __init__(self, *args, patience, toll, **kwargs): From 10feca89379df7cf1205bd0eec0eaf99a5495df4 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Thu, 20 Jun 2024 09:53:54 +0400 Subject: [PATCH 44/45] take a random subset from pool, when creating a prompt --- chemlactica/mol_opt/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index b018a0e..b44f37b 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -223,9 +223,9 @@ def get_train_valid_entries(self): return train_entries, valid_entries def random_subset(self, subset_size): - rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) - # rand_inds = np.random.permutation(len(self.optim_entries)) - # rand_inds = rand_inds[:subset_size] + # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size)) + rand_inds = np.random.permutation(len(self.optim_entries)) + rand_inds = rand_inds[:subset_size] return [self.optim_entries[i] for i in rand_inds] def __len__(self): From 52692ecf32d7e1ca4b938042914faf760ab4a18d Mon Sep 17 00:00:00 2001 From: tigranfah Date: Mon, 24 Jun 2024 17:40:40 +0400 Subject: [PATCH 45/45] fix --- chemlactica/mol_opt/optimization.py | 16 +++++++++++----- chemlactica/mol_opt/utils.py | 15 ++++----------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index b1f17ca..9cb2897 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -142,8 +142,8 @@ def optimize( current_unique_optim_entries[smiles].last_entry.score = oracle_score iter_unique_optim_entries[smiles] = current_unique_optim_entries[smiles] file.write(f"generated smiles: {smiles}, score: 
{current_unique_optim_entries[smiles].last_entry.score:.4f}\n") - if current_unique_optim_entries[smiles].last_entry.score > max_score: - max_score = current_unique_optim_entries[smiles].last_entry.score + if max_score >= config["max_possible_oracle_score"] - 1e-2 or current_unique_optim_entries[smiles].last_entry.score > max_score: + max_score = max(max_score, current_unique_optim_entries[smiles].last_entry.score) new_best_molecule_generated = True print(f"Iter unique optim entries: {len(iter_unique_optim_entries)}, budget: {len(oracle)}") @@ -189,13 +189,19 @@ def optimize( train_dataset = Dataset.from_dict({ "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + optim_entry.to_prompt( + is_generation=False, include_oracle_score=True, + config=config, max_score=config["max_possible_oracle_score"] + ) for optim_entry in train_entries ] }) validation_dataset = Dataset.from_dict({ "sample": [ - optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config) + optim_entry.to_prompt( + is_generation=False, include_oracle_score=True, + config=config, max_score=config["max_possible_oracle_score"] + ) for optim_entry in validation_entries ] }) @@ -225,7 +231,7 @@ def optimize( trainer.train() print(f"Loading the best model state dict with validation loss {model_selection_callback.best_validation_loss}") model.load_state_dict(model_selection_callback.best_model_state_dict) - shutil.rmtree(training_args.output_dir) + del model_selection_callback.best_model_state_dict gc.collect() torch.cuda.empty_cache() tol_level = 0 diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index b44f37b..ec6d2af 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -270,7 +270,7 @@ def __init__(self, last_entry, mol_entries): def to_prompt( self, is_generation: bool, include_oracle_score: bool, config, - max_score=None + max_score ): prompt = "" # prompt = config["eos_token"] @@ -308,16 +308,9 @@ def to_prompt( pass elif "rej-sample-v2" in config["strategy"]: if is_generation: - # oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] - # q_0_9 = ( - # np.quantile(oracle_scores_of_mols_in_prompt, 0.9) - # if oracle_scores_of_mols_in_prompt - # else 0 - # ) - # desired_oracle_score = generate_random_number( - # q_0_9, config["max_possible_oracle_score"] - # ) - desired_oracle_score = max_score + desired_oracle_score = generate_random_number( + max_score, config["max_possible_oracle_score"] + ) oracle_score = desired_oracle_score else: oracle_score = self.last_entry.score
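
The top_auc metric added to metrics.py computes the area under the "mean of the top-n oracle scores so far" curve, sampled every freq_log oracle calls and integrated with the trapezoidal rule, then normalized by the full oracle budget. A standalone restatement with a toy buffer; the simplified helper below and its toy values are illustrative, not part of the patch:

    import numpy as np

    def top_auc(buffer, top_n, finish, freq_log, max_oracle_calls):
        # buffer maps SMILES -> [oracle_score, call_index]; replay in call order
        total, prev, called = 0.0, 0.0, 0
        ordered = sorted(buffer.items(), key=lambda kv: kv[1][1])
        for idx in range(freq_log, min(len(buffer), max_oracle_calls), freq_log):
            # mean of the best top_n scores among the first idx oracle calls
            top_now = np.mean(sorted((kv[1][0] for kv in ordered[:idx]), reverse=True)[:top_n])
            total += freq_log * (top_now + prev) / 2  # trapezoid between checkpoints
            prev, called = top_now, idx
        top_now = np.mean(sorted((kv[1][0] for kv in ordered), reverse=True)[:top_n])
        total += (len(buffer) - called) * (top_now + prev) / 2
        if finish and len(buffer) < max_oracle_calls:
            # a finished run is credited with its final level for the unused budget
            total += (max_oracle_calls - len(buffer)) * top_now
        return total / max_oracle_calls

    toy_buffer = {f"mol{i}": [i / 10, i + 1] for i in range(10)}  # scores 0.0 .. 0.9
    print(top_auc(toy_buffer, top_n=3, finish=True, freq_log=2, max_oracle_calls=20))

Replaying the buffer in call order is what makes the metric sample-efficiency aware: a run that reaches high scores early accumulates more area than one that reaches the same scores late.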
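
ConstraedTPSAOracle averages three per-molecule terms, each mapped into [0, 1]: TPSA divided by 1000 and capped, a linear molecular-weight ramp, and a ring-count threshold. A minimal sketch of the same scoring rule applied to a plain SMILES string, assuming RDKit is installed; the wrapper function and the example molecule are illustrative:

    from rdkit import Chem
    from rdkit.Chem import rdMolDescriptors

    def constrained_tpsa_score(smiles: str) -> float:
        mol = Chem.MolFromSmiles(smiles)
        tpsa_score = min(rdMolDescriptors.CalcTPSA(mol) / 1000, 1)
        weight = rdMolDescriptors.CalcExactMolWt(mol)
        if weight <= 349:
            weight_score = 1.0
        elif weight >= 500:
            weight_score = 0.0
        else:
            # linear ramp from ~1 at 349 Da down to ~0 at 500 Da,
            # i.e. approximately (500 - weight) / 151
            weight_score = -0.00662 * weight + 3.31125
        rings_score = 1.0 if rdMolDescriptors.CalcNumRings(mol) >= 2 else 0.0
        return (tpsa_score + weight_score + rings_score) / 3

    print(constrained_tpsa_score("CN1C=NC2=C1C(=O)N(C(=O)N2C)C"))  # caffeine, 2 rings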
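
In no_slurm_hparam_search.py, is_gpu_being_used implicitly returns None when the nvidia-smi call fails, and the polling loop reads None as "free". A hedged sketch of a stricter probe that treats probe failures as busy; it uses the same nvidia-smi invocation as the patch, only the error path and return type are tightened:

    import subprocess

    def is_gpu_being_used(gpu_id: int) -> bool:
        try:
            output = subprocess.check_output(["nvidia-smi", "-i", str(gpu_id)], text=True)
            return "No running processes found" not in output
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            print(f"Error executing nvidia-smi command: {e}")
            return True  # assume busy so a broken probe never double-books a GPU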
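
Because patches 33 and 34 create the optimizer and polynomial-decay scheduler once and reuse them across all fine-tuning rounds, the total number of optimizer steps has to be estimated up front. A restatement of that arithmetic; the config keys mirror rej_sample_config, while the concrete values are examples only:

    def max_fine_tune_steps(cfg, pool_size, validation_perc, num_gens_per_iter, max_oracle_calls):
        effective_batch = cfg["gradient_accumulation_steps"] * cfg["train_batch_size"]
        # optimizer steps in one round over the training split of the pool
        steps_per_round = cfg["num_train_epochs"] * ((1 - validation_perc) * pool_size / effective_batch)
        # a round is triggered roughly every train_tol_level iterations of num_gens_per_iter calls
        max_rounds = max_oracle_calls / (cfg["train_tol_level"] * num_gens_per_iter)
        return int(max_rounds * steps_per_round)

    cfg = {"gradient_accumulation_steps": 8, "train_batch_size": 4,
           "num_train_epochs": 5, "train_tol_level": 5}
    print(max_fine_tune_steps(cfg, pool_size=50, validation_perc=0.2,
                              num_gens_per_iter=200, max_oracle_calls=15000))  # 93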
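
Patch 34 also anneals the sampling temperature linearly from the first to the second entry of generation_temperature across the oracle budget, stepping once per iteration after the first. A small sketch of that schedule with illustrative budget numbers:

    def annealed_temperature(t_start, t_end, budget, num_gens_per_iter, num_iter):
        # one increment per completed iteration, as in the num_iter > initial_num_iter branch
        step = num_gens_per_iter / (budget - num_gens_per_iter) * (t_end - t_start)
        return t_start + num_iter * step

    for it in range(5):
        print(annealed_temperature(1.0, 1.5, budget=1000, num_gens_per_iter=200, num_iter=it))
    # 1.0, 1.125, 1.25, 1.375, 1.5: the last iteration samples at the upper temperature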
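
CustomModelSelectionCallback (patch 43) clones every tensor on whatever device the model lives on, so the snapshot holds a second full copy of the weights in GPU memory during fine-tuning. A variant that parks the snapshot in host memory instead; this is a sketch under that assumption, not the committed implementation:

    import math
    from collections import OrderedDict
    from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

    class CPUModelSelectionCallback(TrainerCallback):
        def __init__(self):
            super().__init__()
            self.best_validation_loss: float = math.inf
            self.best_model_state_dict: OrderedDict = OrderedDict()

        def on_evaluate(self, args: TrainingArguments, state: TrainerState,
                        control: TrainerControl, metrics, model, **kwargs):
            if metrics["eval_loss"] <= self.best_validation_loss:
                self.best_validation_loss = metrics["eval_loss"]
                # detach and copy to host memory so the snapshot frees no GPU RAM
                self.best_model_state_dict = OrderedDict(
                    (k, v.detach().to("cpu", copy=True)) for k, v in model.state_dict().items()
                )
            return super().on_evaluate(args, state, control, **kwargs)

Restoring still works unchanged, since load_state_dict copies the CPU values back onto the device-resident parameters.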
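
Patches 38, 42 and 44 toggle Pool.random_subset between two behaviours. If optim_entries is kept sorted best-first (the ordering is maintained outside this excerpt, so treat that as an assumption), the difference is whether prompt construction samples only the head of the pool or the whole pool. A side-by-side sketch with stand-in entries:

    import numpy as np

    entries = list(range(10))  # stand-in for optim_entries, assumed best-first
    subset_size = 4

    # "permutation of top elements" (patch 42): shuffle only the first subset_size
    head_only = [entries[i] for i in np.random.permutation(min(len(entries), subset_size))]
    # "random subset from pool" (patches 38/44): draw from the whole pool
    anywhere = [entries[i] for i in np.random.permutation(len(entries))[:subset_size]]
    print(head_only)  # a shuffle of the top-4 entries
    print(anywhere)   # 4 entries drawn from anywhere in the pool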
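
Patch 45 changes to_prompt so that generation is conditioned on a target oracle score drawn between the best score seen so far and max_possible_oracle_score, rather than on the fixed running maximum. A minimal sketch, assuming generate_random_number (defined in utils.py outside these hunks) draws uniformly from its interval:

    import random

    def generate_random_number(lo: float, hi: float) -> float:
        # assumed uniform; the real helper lives in chemlactica/mol_opt/utils.py
        return random.uniform(lo, hi)

    max_score = 0.74          # best oracle score currently in the pool
    max_possible = 1.0        # config["max_possible_oracle_score"]
    desired = generate_random_number(max_score, max_possible)
    print(f"condition generation on a target score of {desired:.2f}")

Sampling the target above the current best, instead of pinning it to the maximum, keeps the conditioning value inside a range the model has actually seen during fine-tuning while still pulling generation upward.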