
Mol opt #24

Merged · 48 commits · Jul 25, 2024

Commits (48)
9587e84
add mol opt
tigranfah Apr 19, 2024
2a9cc8c
mol opt+rej sample
tigranfah Apr 22, 2024
ab876f9
minor logging changes
tigranfah Apr 22, 2024
56e9232
rej-sample-v1
tigranfah Apr 24, 2024
de7a695
add molecules pool dump
tigranfah Apr 27, 2024
40f7dde
add pool dump
tigranfah Apr 29, 2024
7559e5b
add feature
tigranfah May 3, 2024
a023281
don't enter scoring code if oracle has exceeded budget
philippguevorguian May 7, 2024
8ae3162
add file generation utility
philippguevorguian May 7, 2024
f2ee468
rej-sample-v2
tigranfah May 10, 2024
44cfb4b
rej-sample-v2.1
tigranfah May 13, 2024
898f5c1
refine
tigranfah May 14, 2024
a68b4f2
rej-sample-v2
tigranfah May 10, 2024
f5e1ca2
rej-sample-v2.1
tigranfah May 13, 2024
52b3a64
merge
tigranfah May 14, 2024
ca32500
merge
tigranfah May 14, 2024
cff4db7
rej-sample-v2 refac
tigranfah May 15, 2024
efd8106
remove duplicates if any from the optim process
tigranfah May 17, 2024
cade253
add hash type for molecular entries to store
philippguevorguian May 17, 2024
70642da
merge with local
philippguevorguian May 17, 2024
a805c1b
add logit processors to package
philippguevorguian May 17, 2024
20d3fe0
add train condition
tigranfah May 17, 2024
bfa32d7
don't give oracle score tag before the first training
tigranfah May 18, 2024
d51cbfb
add validation set to fine tuning
tigranfah May 20, 2024
1502e65
replace [ORACLE_SCORE] with [PROPERTY] tag
tigranfah May 20, 2024
188f384
add additional properties
tigranfah May 21, 2024
d95e4d4
correct properties order
tigranfah May 21, 2024
800278b
fix entry duplicates in pool issue
tigranfah May 21, 2024
1306842
pre merge
philippguevorguian May 22, 2024
d9f85d1
Merge branch 'mol_opt' of https://github.com/YerevaNN/ChemLactica int…
philippguevorguian May 22, 2024
b98a2cb
add validation batch size, to avoid memory error
tigranfah May 22, 2024
32cc5aa
small fix
tigranfah May 22, 2024
0ac3a7b
small fixes
tigranfah May 25, 2024
d8ac65f
add hparam config
tigranfah May 29, 2024
05af82d
add no slurm parallel runs to mol_opt
tigranfah May 29, 2024
16a00ca
keep optim state during the fine-tuning of mol optim process
tigranfah May 29, 2024
832906c
refine
tigranfah Jun 3, 2024
1f8142f
add lr annealing/not annealing
tigranfah Jun 6, 2024
91d9390
add ending lr in config
tigranfah Jun 6, 2024
2cf0f8b
remove multiprocessing
tigranfah Jun 6, 2024
9c96919
change sampling from pool
tigranfah Jun 7, 2024
5d7a258
fix
tigranfah Jun 7, 2024
be4096d
add model selection based on validation loss
tigranfah Jun 8, 2024
a84c79c
remove model selection with validation loss, because of slowness
tigranfah Jun 9, 2024
c393b1a
take random permutation of top elements from the pool, when construct…
tigranfah Jun 11, 2024
5537720
add custom model selection with fine-tuning validation loss
tigranfah Jun 11, 2024
10feca8
take a random subset from pool, when creating a prompt
tigranfah Jun 20, 2024
52692ec
fix
tigranfah Jun 24, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
chemlactica/mol_opt/results/

# Aim experiment metadata
aim/

Empty file added chemlactica/mol_opt/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions chemlactica/mol_opt/chemlactica_125m_hparams.yaml
@@ -0,0 +1,40 @@
# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480
# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-12288
checkpoint_path: /home/admin/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000
tokenizer_path: /home/admin/tigran/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66
pool_size: 50
validation_perc: 0.2
num_mols: 0
num_similars: 1
num_gens_per_iter: 200
device: cuda:0
sim_range: [0.8, 0.9]
# qed_range: [0.5, 0.9]
num_processes: 8
generation_batch_size: 200
eos_token: "</s>"
generation_temperature: [1.0, 1.5]

generation_config:
  repetition_penalty: 1.0
  max_new_tokens: 100
  do_sample: true
  eos_token_id: 20

strategy: [default]

rej_sample_config:
  train_tol_level: 3
  checkpoints_dir: ./
  max_learning_rate: 0.00001
  train_batch_size: 2
  gradient_accumulation_steps: 8
  weight_decay: 0.1
  adam_beta1: 0.9
  adam_beta2: 0.999
  warmup_steps: 0
  global_gradient_norm: 1.0
  dataloader_num_workers: 1
  max_seq_length: 2048
  num_train_epochs: 5
  packing: false
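
For reference, a minimal sketch of loading this config the same way hparam_search.py below does, and of passing the generation_config block to a Hugging Face model. How optimize() actually consumes these fields is not shown in this diff, so the generate() call and the prompt are assumptions.

import yaml
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

config = yaml.safe_load(open("chemlactica/mol_opt/chemlactica_125m_hparams.yaml"))
model = AutoModelForCausalLM.from_pretrained(
    config["checkpoint_path"], torch_dtype=torch.bfloat16
).to(config["device"])
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")

# assumption: the generation_config block maps directly onto generate() keyword arguments
inputs = tokenizer("CCO", return_tensors="pt").to(config["device"])  # "CCO" is a placeholder prompt
outputs = model.generate(**inputs, **config["generation_config"])
print(tokenizer.decode(outputs[0]))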
55 changes: 55 additions & 0 deletions chemlactica/mol_opt/hparam_search.py
@@ -0,0 +1,55 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import yaml
import datetime
import argparse
import os
from utils import ConstraedTPSAOracle
from typing import List
from chemlactica.mol_opt.optimization import optimize

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def default_train_condition(num_iter, tol_level, prev_train_iter):
    return num_iter - prev_train_iter >= 3


def tolerance_train_condition(cur_tol_level, train_tol_level):
    return cur_tol_level >= train_tol_level


def choose_train_condition(name):
    return {
        "default": default_train_condition,
        "tolerance": tolerance_train_condition
    }[name]


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, required=False)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--config_default", type=str, required=False, default="chemlactica/chemlactica_125m_hparams.yaml")
    parser.add_argument("--n_runs", type=int, required=False, default=1)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_arguments()
    config = yaml.safe_load(open(args.config_default))
    print(config)

    model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"])
    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")

    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    oracle = ConstraedTPSAOracle(max_oracle_calls=15000)
    for seed in seeds[:args.n_runs]:
        config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log")
        config["rej_sample_config"]["should_train"] = choose_train_condition("tolerance")
        optimize(
            model, tokenizer,
            oracle, config
        )
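
The train-condition helpers above appear to gate when the rejection-sampling fine-tuning step is triggered (compare train_tol_level in rej_sample_config). A minimal sketch of their behaviour, using only the functions defined in this file; how optimize() calls should_train is not visible in this diff.

should_train = choose_train_condition("tolerance")
print(should_train(cur_tol_level=3, train_tol_level=3))     # True: tolerance threshold reached, fine-tune
print(should_train(cur_tol_level=1, train_tol_level=3))     # False: keep generating without fine-tuning

retrain = choose_train_condition("default")
print(retrain(num_iter=7, tol_level=0, prev_train_iter=4))  # True: 3 or more iterations since the last training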
18 changes: 18 additions & 0 deletions chemlactica/mol_opt/hparams_tune.yaml
@@ -0,0 +1,18 @@
name: chemlactica
method: grid
metric:
  goal: maximize
  name: avg_auc
parameters:
  strategy: [[default]]

  pool_size: [10, 30, 50]
  num_mols: [0, 1, 2, 3, 5]
  num_similars: [0, 1, 2, 3, 5]
  num_gens_per_iter: [200, 400, 600]
  generation_temperature: [[1.0, 1.0], [1.5, 1.5], [1.0, 1.5]]

  # rej_sample_config:
  #   num_train_epochs: [1, 3, 5, 7, 9]
  #   train_tol_level: [1, 3, 5, 7, 9]
  #   max_learning_rate: [0.0001, 0.00001, 0.000001]
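
With the rej_sample_config block commented out, this grid spans 1 × 3 × 5 × 5 × 3 × 3 = 675 combinations (strategy × pool_size × num_mols × num_similars × num_gens_per_iter × generation_temperature), which create_hparam_configs in no_slurm_hparam_search.py below expands with itertools.product.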
71 changes: 71 additions & 0 deletions chemlactica/mol_opt/metrics.py
@@ -0,0 +1,71 @@
import numpy as np
import torch


def top_auc(buffer, top_n, finish, freq_log, max_oracle_calls):
    sum = 0
    prev = 0
    called = 0
    ordered_results = list(sorted(buffer.items(), key=lambda kv: kv[1][1], reverse=False))
    for idx in range(freq_log, min(len(buffer), max_oracle_calls), freq_log):
        temp_result = ordered_results[:idx]
        temp_result = list(sorted(temp_result, key=lambda kv: kv[1][0], reverse=True))[:top_n]
        top_n_now = np.mean([item[1][0] for item in temp_result])
        sum += freq_log * (top_n_now + prev) / 2
        prev = top_n_now
        called = idx
    temp_result = list(sorted(ordered_results, key=lambda kv: kv[1][0], reverse=True))[:top_n]
    top_n_now = np.mean([item[1][0] for item in temp_result])
    sum += (len(buffer) - called) * (top_n_now + prev) / 2
    if finish and len(buffer) < max_oracle_calls:
        sum += (max_oracle_calls - len(buffer)) * top_n_now
    return sum / max_oracle_calls


def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For each molecule in gen_vecs finds the closest molecule in stock_vecs.
    Returns the average Tanimoto score between these molecules.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        agg: max or mean
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac**p
            if agg == 'max':
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = (agg_tanimoto)**(1/p)
    return np.mean(agg_tanimoto)


def internal_diversity(molecule_fingerprints, device='cpu', fp_type='morgan', p=1):
    """
    Computes internal diversity as:
    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
    """
    return 1 - (average_agg_tanimoto(molecule_fingerprints, molecule_fingerprints,
                                     agg='mean', device=device, p=p)).mean()
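
A small usage sketch of these metrics. The buffer layout for top_auc (molecule -> (oracle score, oracle-call index)) is inferred from the sort keys above rather than documented in this diff, and the fingerprints are random placeholders.

import numpy as np

buffer = {
    "CCO": (0.41, 1),
    "c1ccccc1": (0.93, 2),
    "CCN": (0.70, 3),
    "CCC": (0.55, 4),
}
# area under the "mean of the top-2 scores so far" curve, normalized by the oracle budget
print(top_auc(buffer, top_n=2, finish=True, freq_log=1, max_oracle_calls=10))

# internal diversity over a binary fingerprint matrix (one row per molecule)
fps = np.random.randint(0, 2, size=(16, 2048))
print(internal_diversity(fps))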
119 changes: 119 additions & 0 deletions chemlactica/mol_opt/no_slurm_hparam_search.py
@@ -0,0 +1,119 @@
import submitit
import subprocess
import itertools as it
import datetime
import yaml
import os
import copy
import time
import torch


def is_gpu_being_used(gpu_id):
    try:
        # Run the nvidia-smi command
        cmd = ['nvidia-smi', '-i', f"{gpu_id}"]
        output = subprocess.check_output(cmd)
        output = output.decode('utf-8')
        if "No running processes found" in output:
            return False
        else:
            return True

    except subprocess.CalledProcessError as e:
        print(f"Error executing nvidia-smi command: {e}")


def create_hparam_configs(config_file_path):
    config_tune = yaml.safe_load(open("hparams_tune.yaml"))
    config_merged = {}
    for key, value in config_tune["parameters"].items():
        if type(value) == list:
            config_merged[key] = value
        else:
            for k, v in value.items():
                config_merged[key+'+'+k] = v

    config_default = yaml.safe_load(open(config_file_path))
    hparam_names = list(config_merged.keys())
    all_configs = []
    for params in it.product(*config_merged.values()):
        # pprint(params)
        # pprint(hparam_names)
        config = copy.deepcopy(config_default)
        for i, p in enumerate(params):
            if '+' in hparam_names[i]:
                a, b = hparam_names[i].split("+")
                config[a][b] = p
            else:
                config[hparam_names[i]] = p
        # pprint(params)
        # pprint(config)
        all_configs.append(config)
        # print(config)
    return all_configs


if __name__ == "__main__":
    n_runs = 3

    config_file_path = "chemlactica_125m_hparams.yaml"
    # config_file_path = "main/chemlactica/chemma_2b_hparams.yaml"
    hparam_configs = create_hparam_configs(config_file_path)
    # infer_config = [yaml.safe_load(open(config_file_path))]
    model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2])
    gpu_indices = [0, 1, 2, 3, 4, 5, 6, 7]

    index = 0
    while index < len(hparam_configs):
        free_gpu_index = None
        for gpu_index in gpu_indices:
            gpu_is_free = True
            print(f"Checking gpu: {gpu_index}")
            for _ in range(10):
                if is_gpu_being_used(gpu_index):
                    gpu_is_free = False
                    break
                time.sleep(1)
            if gpu_is_free:
                free_gpu_index = gpu_index
                print(f"gpu: {gpu_index} is free")
                break
            else:
                print(f"gpu: {gpu_index} is being used")
        if free_gpu_index is not None:
            print(f"found a free gpu {free_gpu_index}, putting a job")
            executor = submitit.LocalExecutor(folder="/home/admin/tigran/slurm_jobs/PMO/job_%j")
            executor.update_parameters(
                name="chemlactica-pmo", timeout_min=n_runs * 12 * 60,
                visible_gpus=[free_gpu_index],
                gpus_per_node=1, nodes=1, mem_gb=80, cpus_per_task=8,
                slurm_array_parallelism=10
            )
            jobs = []
            with executor.batch():
                current_hparams = [hparam_configs[index]]
                for config in current_hparams:
                    formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d")
                    base = f"results/{formatted_date_time}"
                    os.makedirs(base, exist_ok=True)
                    v = 0
                    name = model_name + "-" + "+".join(config["strategy"])
                    while os.path.exists(os.path.join(base, f"{name}-{v}-hparam-search")):
                        v += 1
                    output_dir = os.path.join(base, f"{name}-{v}-hparam-search")
                    os.makedirs(output_dir, exist_ok=True)
                    yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w"))
                    function = submitit.helpers.CommandFunction([
                        'python3', 'hparam_search.py',
                        '--config_default', os.path.join(output_dir, "hparams.yaml"),
                        '--output_dir', output_dir,
                        '--n_runs', str(n_runs),
                    ])
                    print(' '.join(function.command))
                    job = executor.submit(function)
                    jobs.append(job)
            for job in jobs:
                print(job.job_id)
            index += 1
            free_gpu_index = None
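
For clarity, a small standalone example of the key-flattening convention create_hparam_configs relies on; the nested block mirrors the commented-out rej_sample_config section of hparams_tune.yaml and the values are illustrative only.

import itertools as it

parameters = {
    "pool_size": [10, 50],
    "rej_sample_config": {"num_train_epochs": [1, 5]},
}

# nested mappings are flattened into "outer+inner" keys before the grid product
config_merged = {}
for key, value in parameters.items():
    if isinstance(value, list):
        config_merged[key] = value
    else:
        for k, v in value.items():
            config_merged[key + '+' + k] = v

print(list(it.product(*config_merged.values())))
# [(10, 1), (10, 5), (50, 1), (50, 5)] -> each tuple becomes one run config, with
# "rej_sample_config+num_train_epochs" written back as config["rej_sample_config"]["num_train_epochs"]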