diff --git a/.gitignore b/.gitignore
index 5919610..2b2f680 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+chemlactica/mol_opt/results/
+
 # Aim experiment metadata
 aim/
diff --git a/chemlactica/mol_opt/__init__.py b/chemlactica/mol_opt/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml
new file mode 100644
index 0000000..f9b8ad3
--- /dev/null
+++ b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml
@@ -0,0 +1,42 @@
+# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480
+# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-12288
+checkpoint_path: /home/admin/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000
+tokenizer_path: /home/admin/tigran/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66
+pool_size: 50
+validation_perc: 0.2
+num_mols: 0
+num_similars: 1
+num_gens_per_iter: 200
+device: cuda:0
+sim_range: [0.8, 0.9]
+# qed_range: [0.5, 0.9]
+num_processes: 8
+generation_batch_size: 200
+eos_token: ""
+generation_temperature: [1.0, 1.5]
+max_possible_oracle_score: 1.0  # read by optimize(); matches ConstraedTPSAOracle
+
+generation_config:
+  repetition_penalty: 1.0
+  max_new_tokens: 100
+  do_sample: true
+  eos_token_id: 20
+
+strategy: [default]
+
+rej_sample_config:
+  train_tol_level: 3
+  checkpoints_dir: ./
+  max_learning_rate: 0.00001
+  lr_end: 0.000001  # assumed final LR for the polynomial decay schedule in tunning.py
+  train_batch_size: 2
+  gradient_accumulation_steps: 8
+  weight_decay: 0.1
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  warmup_steps: 0
+  global_gradient_norm: 1.0
+  dataloader_num_workers: 1
+  max_seq_length: 2048
+  num_train_epochs: 5
+  packing: false
\ No newline at end of file
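Note on generation_temperature above: optimize() (optimization.py below) anneals the sampling temperature linearly from the first to the second value across the oracle budget, adding num_gens_per_iter / (budget - num_gens_per_iter) * (t_max - t_min) per completed iteration. With [1.0, 1.5], num_gens_per_iter=200 and a 15000-call budget, that is roughly 200 / 14800 * 0.5 = 0.0068 per iteration.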
diff --git a/chemlactica/mol_opt/hparam_search.py b/chemlactica/mol_opt/hparam_search.py
new file mode 100644
index 0000000..4730f78
--- /dev/null
+++ b/chemlactica/mol_opt/hparam_search.py
@@ -0,0 +1,56 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+import yaml
+import datetime
+import argparse
+import os
+from utils import ConstraedTPSAOracle, set_seed
+from typing import List
+from chemlactica.mol_opt.optimization import optimize
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+
+def default_train_condition(num_iter, tol_level, prev_train_iter):
+    return num_iter - prev_train_iter >= 3
+
+
+def tolerance_train_condition(cur_tol_level, train_tol_level):
+    return cur_tol_level >= train_tol_level
+
+
+def choose_train_condition(name):
+    return {
+        "default": default_train_condition,
+        "tolerance": tolerance_train_condition
+    }[name]
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--run_name", type=str, required=False)
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--config_default", type=str, required=False, default="chemlactica/chemlactica_125m_hparams.yaml")
+    parser.add_argument("--n_runs", type=int, required=False, default=1)
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    config = yaml.safe_load(open(args.config_default))
+    print(config)
+
+    model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"])
+    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")
+
+    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
+    for seed in seeds[:args.n_runs]:
+        set_seed(seed)
+        # a fresh oracle per run, so each seed gets the full call budget
+        oracle = ConstraedTPSAOracle(max_oracle_calls=15000)
+        config["log_dir"] = os.path.join(args.output_dir, f"results_tpsa+weight+num_rungs_seed{seed}.log")
+        config["rej_sample_config"]["should_train"] = choose_train_condition("tolerance")
+        optimize(
+            model, tokenizer,
+            oracle, config
+        )
\ No newline at end of file
diff --git a/chemlactica/mol_opt/hparams_tune.yaml b/chemlactica/mol_opt/hparams_tune.yaml
new file mode 100644
index 0000000..71721f6
--- /dev/null
+++ b/chemlactica/mol_opt/hparams_tune.yaml
@@ -0,0 +1,18 @@
+name: chemlactica
+method: grid
+metric:
+  goal: maximize
+  name: avg_auc
+parameters:
+  strategy: [[default]]
+
+  pool_size: [10, 30, 50]
+  num_mols: [0, 1, 2, 3, 5]
+  num_similars: [0, 1, 2, 3, 5]
+  num_gens_per_iter: [200, 400, 600]
+  generation_temperature: [[1.0, 1.0], [1.5, 1.5], [1.0, 1.5]]
+
+  # rej_sample_config:
+  #   num_train_epochs: [1, 3, 5, 7, 9]
+  #   train_tol_level: [1, 3, 5, 7, 9]
+  #   max_learning_rate: [0.0001, 0.00001, 0.000001]
\ No newline at end of file
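Note on the grid above: create_hparam_configs in the search scripts below flattens one level of nesting under parameters by joining keys with "+" (so the commented-out block would contribute keys like rej_sample_config+num_train_epochs). As written, the grid enumerates 1 * 3 * 5 * 5 * 3 * 3 = 675 configurations.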
diff --git a/chemlactica/mol_opt/metrics.py b/chemlactica/mol_opt/metrics.py
new file mode 100644
index 0000000..f063171
--- /dev/null
+++ b/chemlactica/mol_opt/metrics.py
@@ -0,0 +1,71 @@
+import numpy as np
+import torch
+
+
+def top_auc(buffer, top_n, finish, freq_log, max_oracle_calls):
+    auc_sum = 0
+    prev = 0
+    called = 0
+    ordered_results = list(sorted(buffer.items(), key=lambda kv: kv[1][1], reverse=False))
+    for idx in range(freq_log, min(len(buffer), max_oracle_calls), freq_log):
+        temp_result = ordered_results[:idx]
+        temp_result = list(sorted(temp_result, key=lambda kv: kv[1][0], reverse=True))[:top_n]
+        top_n_now = np.mean([item[1][0] for item in temp_result])
+        auc_sum += freq_log * (top_n_now + prev) / 2
+        prev = top_n_now
+        called = idx
+    temp_result = list(sorted(ordered_results, key=lambda kv: kv[1][0], reverse=True))[:top_n]
+    top_n_now = np.mean([item[1][0] for item in temp_result])
+    auc_sum += (len(buffer) - called) * (top_n_now + prev) / 2
+    if finish and len(buffer) < max_oracle_calls:
+        auc_sum += (max_oracle_calls - len(buffer)) * top_n_now
+    return auc_sum / max_oracle_calls
+
+
+def average_agg_tanimoto(stock_vecs, gen_vecs,
+                         batch_size=5000, agg='max',
+                         device='cpu', p=1):
+    """
+    For each molecule in gen_vecs, finds the closest molecule in stock_vecs.
+    Returns the average Tanimoto score between these molecules.
+
+    Parameters:
+        stock_vecs: numpy array
+        gen_vecs: numpy array
+        agg: max or mean
+        p: power for averaging: (mean x^p)^(1/p)
+    """
+    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
+    agg_tanimoto = np.zeros(len(gen_vecs))
+    total = np.zeros(len(gen_vecs))
+    for j in range(0, stock_vecs.shape[0], batch_size):
+        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
+        for i in range(0, gen_vecs.shape[0], batch_size):
+            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
+            y_gen = y_gen.transpose(0, 1)
+            tp = torch.mm(x_stock, y_gen)
+            jac = (tp / (x_stock.sum(1, keepdim=True) +
+                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
+            jac[np.isnan(jac)] = 1
+            if p != 1:
+                jac = jac**p
+            if agg == 'max':
+                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
+                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
+            elif agg == 'mean':
+                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
+                total[i:i + y_gen.shape[1]] += jac.shape[0]
+    if agg == 'mean':
+        agg_tanimoto /= total
+    if p != 1:
+        agg_tanimoto = (agg_tanimoto)**(1/p)
+    return np.mean(agg_tanimoto)
+
+
+def internal_diversity(molecule_fingerprints, device='cpu', fp_type='morgan', p=1):
+    """
+    Computes internal diversity as:
+    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
+    """
+    return 1 - (average_agg_tanimoto(molecule_fingerprints, molecule_fingerprints,
+                                     agg='mean', device=device, p=p)).mean()
\ No newline at end of file
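For reference, top_auc consumes a buffer in the format the oracle in utils.py maintains: a dict mapping SMILES to [score, call_index]. A minimal sketch with a hypothetical three-molecule buffer (values invented for illustration):

    from chemlactica.mol_opt.metrics import top_auc

    # Toy buffer: SMILES -> [oracle_score, 1-based order in which it was scored]
    buffer = {
        "CCO": [0.40, 1],
        "CCN": [0.70, 2],
        "CCC": [0.55, 3],
    }
    # AUC of the running top-1 score over a 3-call budget, logged every call:
    # trapezoid over the best-so-far curve, ~0.483 here.
    print(top_auc(buffer, top_n=1, finish=True, freq_log=1, max_oracle_calls=3))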
diff --git a/chemlactica/mol_opt/no_slurm_hparam_search.py b/chemlactica/mol_opt/no_slurm_hparam_search.py
new file mode 100644
index 0000000..219c0c9
--- /dev/null
+++ b/chemlactica/mol_opt/no_slurm_hparam_search.py
@@ -0,0 +1,120 @@
+import submitit
+import subprocess
+import itertools as it
+import datetime
+import yaml
+import os
+import copy
+import time
+import torch
+
+
+def is_gpu_being_used(gpu_id):
+    try:
+        # Run the nvidia-smi command
+        cmd = ['nvidia-smi', '-i', f"{gpu_id}"]
+        output = subprocess.check_output(cmd)
+        output = output.decode('utf-8')
+        if "No running processes found" in output:
+            return False
+        else:
+            return True
+
+    except subprocess.CalledProcessError as e:
+        print(f"Error executing nvidia-smi command: {e}")
+        return False  # treat a failed query as free, matching the previous implicit None
+
+
+def create_hparam_configs(config_file_path):
+    config_tune = yaml.safe_load(open("hparams_tune.yaml"))
+    config_merged = {}
+    for key, value in config_tune["parameters"].items():
+        if isinstance(value, list):
+            config_merged[key] = value
+        else:
+            for k, v in value.items():
+                config_merged[key+'+'+k] = v
+
+    config_default = yaml.safe_load(open(config_file_path))
+    hparam_names = list(config_merged.keys())
+    all_configs = []
+    for params in it.product(*config_merged.values()):
+        # pprint(params)
+        # pprint(hparam_names)
+        config = copy.deepcopy(config_default)
+        for i, p in enumerate(params):
+            if '+' in hparam_names[i]:
+                a, b = hparam_names[i].split("+")
+                config[a][b] = p
+            else:
+                config[hparam_names[i]] = p
+        # pprint(params)
+        # pprint(config)
+        all_configs.append(config)
+        # print(config)
+    return all_configs
+
+
+if __name__ == "__main__":
+    n_runs = 3
+
+    config_file_path = "chemlactica_125m_hparams.yaml"
+    # config_file_path = "main/chemlactica/chemma_2b_hparams.yaml"
+    hparam_configs = create_hparam_configs(config_file_path)
+    # infer_config = [yaml.safe_load(open(config_file_path))]
+    model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2])
+    gpu_indices = [0, 1, 2, 3, 4, 5, 6, 7]
+
+    index = 0
+    while index < len(hparam_configs):
+        free_gpu_index = None
+        for gpu_index in gpu_indices:
+            gpu_is_free = True
+            print(f"Checking gpu: {gpu_index}")
+            for _ in range(10):
+                if is_gpu_being_used(gpu_index):
+                    gpu_is_free = False
+                    break
+                time.sleep(1)
+            if gpu_is_free:
+                free_gpu_index = gpu_index
+                print(f"gpu: {gpu_index} is free")
+                break
+            else:
+                print(f"gpu: {gpu_index} is being used")
+        if free_gpu_index is not None:
+            print(f"found a free gpu {free_gpu_index}, putting a job")
+            executor = submitit.LocalExecutor(folder="/home/admin/tigran/slurm_jobs/PMO/job_%j")
+            executor.update_parameters(
+                name="chemlactica-pmo", timeout_min=n_runs * 12 * 60,
+                visible_gpus=[free_gpu_index],
+                gpus_per_node=1, nodes=1, mem_gb=80, cpus_per_task=8,
+                slurm_array_parallelism=10
+            )
+            jobs = []
+            with executor.batch():
+                current_hparams = [hparam_configs[index]]
+                for config in current_hparams:
+                    formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d")
+                    base = f"results/{formatted_date_time}"
+                    os.makedirs(base, exist_ok=True)
+                    v = 0
+                    name = model_name + "-" + "+".join(config["strategy"])
+                    while os.path.exists(os.path.join(base, f"{name}-{v}-hparam-search")):
+                        v += 1
+                    output_dir = os.path.join(base, f"{name}-{v}-hparam-search")
+                    os.makedirs(output_dir, exist_ok=True)
+                    yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w"))
+                    function = submitit.helpers.CommandFunction([
+                        'python3', 'hparam_search.py',
+                        '--config_default', os.path.join(output_dir, "hparams.yaml"),
+                        '--output_dir', output_dir,
+                        '--n_runs', str(n_runs),
+                    ])
+                    print(' '.join(function.command))
+                    job = executor.submit(function)
+                    jobs.append(job)
+            for job in jobs:
+                print(job.job_id)
+            index += 1
+            free_gpu_index = None
\ No newline at end of file
diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py
new file mode 100644
index 0000000..9cb2897
--- /dev/null
+++ b/chemlactica/mol_opt/optimization.py
@@ -0,0 +1,238 @@
+from typing import Dict, List
+import torch
+from datasets import Dataset
+import gc
+import shutil
+from trl import SFTTrainer
+from transformers import OPTForCausalLM
+from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool
+from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback, CustomModelSelectionCallback
+
+
+def create_similar_mol_entries(pool, mol_entry, num_similars):
+    similar_entries = [e.last_entry for e in pool.random_subset(num_similars)]
+    count = 0
+    valid_similar_entries = []
+    for similar_entry in similar_entries:
+        if count >= num_similars:
+            break
+        if similar_entry == mol_entry:
+            continue
+        valid_similar_entries.append(similar_entry)
+        count += 1
+    return valid_similar_entries
+
+
+def create_optimization_entries(num_entries, pool, config):
+    optim_entries = []
+    for i in range(num_entries):
+        mol_entries = [e.last_entry for e in pool.random_subset(config["num_mols"])]
+        entries = []
+        for mol_entry in mol_entries:
+            similar_mol_entries = create_similar_mol_entries(pool, mol_entry, num_similars=config["num_similars"])
+            mol_entry.similar_mol_entries = similar_mol_entries
+            entries.append(mol_entry)
+        optim_entries.append(OptimEntry(None, entries))
+    return optim_entries
+
+
+def create_molecule_entry(output_text):
+    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
+    start_ind = output_text.rfind(start_smiles_tag)
+    end_ind = output_text.rfind(end_smiles_tag)
+    if start_ind == -1 or end_ind == -1:
+        return None
+    generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
+    if len(generated_smiles) == 0:
+        return None
+
+    try:
+        molecule = MoleculeEntry(
+            smiles=generated_smiles,
+        )
+        return molecule
+    except Exception:
+        return None
+
+
+def optimize(
+    model, tokenizer,
+    oracle, config,
+    additional_properties={}
+):
+    file = open(config["log_dir"], "w")
+    print("config", config)
+    # print("molecule generation arguments", config["generation_config"])
+    pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])
+
+    config["generation_config"]["temperature"] = config["generation_temperature"][0]
+
+    if "rej-sample-v2" in config["strategy"]:
+        training_args = get_training_arguments(config["rej_sample_config"])
+        effective_batch_size = config["rej_sample_config"]["gradient_accumulation_steps"] * config["rej_sample_config"]["train_batch_size"]
+        num_single_train_steps = config["rej_sample_config"]["num_train_epochs"] * ((1 - config["validation_perc"]) * config["pool_size"] / effective_batch_size)
+        max_num_trains = oracle.max_oracle_calls / (config["rej_sample_config"]["train_tol_level"] * config["num_gens_per_iter"])
+        max_num_train_steps = int(max_num_trains * num_single_train_steps)
+        optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], max_num_train_steps)
+    max_score = 0
+    tol_level = 0
+    num_iter = 0
+    prev_train_iter = 0
+    while True:
+        model.eval()
+        new_best_molecule_generated = False
+        iter_unique_optim_entries: Dict[str, OptimEntry] = {}
+        while len(iter_unique_optim_entries) < config["num_gens_per_iter"]:
+            optim_entries = create_optimization_entries(
+                config["generation_batch_size"], pool,
+                config=config
+            )
+            for i in range(len(optim_entries)):
+                last_entry = MoleculeEntry(smiles="")
+                last_entry.similar_mol_entries = create_similar_mol_entries(
+                    pool, last_entry, config["num_similars"]
+                )
+                for prop_name, prop_spec in additional_properties.items():
+                    last_entry.add_props[prop_name] = prop_spec
+                optim_entries[i].last_entry = last_entry
+
+            prompts = [
+                optim_entry.to_prompt(
+                    is_generation=True, include_oracle_score=prev_train_iter != 0,
+                    config=config, max_score=max_score
+                )
+                for optim_entry in optim_entries
+            ]
+            output_texts = []
+            data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+            if type(model) == OPTForCausalLM:
+                del data["token_type_ids"]
+            for key, value in data.items():
+                data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
+            output = model.generate(
+                **data,
+                **config["generation_config"]
+            )
+            gc.collect()
+            torch.cuda.empty_cache()
+            output_texts.extend(tokenizer.batch_decode(output))
+
+            current_unique_optim_entries = {}
+            # with multiprocessing.Pool(processes=config["num_processes"]) as pol:
+            for i, molecule in enumerate(map(create_molecule_entry, output_texts)):
+                if molecule and not optim_entries[i].contains_entry(molecule):
+                    if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries:
+                        molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries
+                        for prop_name, prop_spec in additional_properties.items():
+                            molecule.add_props[prop_name] = prop_spec
+                            molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule)
+                        optim_entries[i].last_entry = molecule
+                        current_unique_optim_entries[molecule.smiles] = optim_entries[i]
+
+            num_of_molecules_to_score = min(len(current_unique_optim_entries), config["num_gens_per_iter"] - len(iter_unique_optim_entries))
+            current_unique_smiles_list = list(current_unique_optim_entries.keys())[:num_of_molecules_to_score]
+            current_unique_optim_entries = {smiles: current_unique_optim_entries[smiles] for smiles in current_unique_smiles_list}
+
+            if getattr(oracle, "takes_entry", False):
+                oracle_scores = oracle([current_unique_optim_entries[smiles].last_entry for smiles in current_unique_smiles_list])
+            else:
+                oracle_scores = oracle(current_unique_smiles_list)
+
+            for smiles, oracle_score in zip(current_unique_smiles_list, oracle_scores):
+                current_unique_optim_entries[smiles].last_entry.score = oracle_score
+                iter_unique_optim_entries[smiles] = current_unique_optim_entries[smiles]
+                file.write(f"generated smiles: {smiles}, score: {current_unique_optim_entries[smiles].last_entry.score:.4f}\n")
+                if max_score >= config["max_possible_oracle_score"] - 1e-2 or current_unique_optim_entries[smiles].last_entry.score > max_score:
+                    max_score = max(max_score, current_unique_optim_entries[smiles].last_entry.score)
+                    new_best_molecule_generated = True
+
+            print(f"Iter unique optim entries: {len(iter_unique_optim_entries)}, budget: {len(oracle)}")
+
+            if oracle.finish:
+                break
+
+        if oracle.finish:
+            break
+        initial_num_iter = num_iter
+        num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"]
+        if num_iter > initial_num_iter:
+            tol_level += 1
+
+        if new_best_molecule_generated:
+            tol_level = 0
+
+        print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}")
+        if num_iter > initial_num_iter:
+            config["generation_config"]["temperature"] += config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"]) * (config["generation_temperature"][1] - config["generation_temperature"][0])
+            print(f"Generation temperature: {config['generation_config']['temperature']}")
+
+        # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10))
+        pool.add(list(iter_unique_optim_entries.values()))
+        file.write("Pool\n")
+        for i, optim_entry in enumerate(pool.optim_entries):
+            file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+
+        if "rej-sample-v2" in config["strategy"]:
+            # round_entries.extend(current_entries)
+            # round_entries = list(np.unique(round_entries))[::-1]
+            # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
+            # if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
+            if tol_level >= config["rej_sample_config"]["train_tol_level"]:
+                train_entries, validation_entries = pool.get_train_valid_entries()
+                print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.")
+                file.write("Training entries\n")
+                for i, optim_entry in enumerate(train_entries):
+                    file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+                file.write("Validation entries\n")
+                for i, optim_entry in enumerate(validation_entries):
+                    file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
+
+                train_dataset = Dataset.from_dict({
+                    "sample": [
+                        optim_entry.to_prompt(
+                            is_generation=False, include_oracle_score=True,
+                            config=config, max_score=config["max_possible_oracle_score"]
+                        )
+                        for optim_entry in train_entries
+                    ]
+                })
+                validation_dataset = Dataset.from_dict({
+                    "sample": [
+                        optim_entry.to_prompt(
+                            is_generation=False, include_oracle_score=True,
+                            config=config, max_score=config["max_possible_oracle_score"]
+                        )
+                        for optim_entry in validation_entries
+                    ]
+                })
+                # Dataset.shuffle returns a new dataset; assign the result back
+                train_dataset = train_dataset.shuffle(seed=42)
+                validation_dataset = validation_dataset.shuffle(seed=42)
+
+                # early_stopping_callback = CustomEarlyStopCallback(
+                #     early_stopping_patience=1,
+                #     early_stopping_threshold=0.0001
+                # )
+                model_selection_callback = CustomModelSelectionCallback()
+
+                model.train()
+                trainer = SFTTrainer(
+                    model=model,
+                    train_dataset=train_dataset,
+                    eval_dataset=validation_dataset,
+                    formatting_func=lambda x: x["sample"],
+                    args=training_args,
+                    packing=config["rej_sample_config"]["packing"],
+                    tokenizer=tokenizer,
+                    max_seq_length=config["rej_sample_config"]["max_seq_length"],
+                    # data_collator=collator,
+                    callbacks=[model_selection_callback],
+                    optimizers=(optimizer, lr_scheduler),
+                )
+                trainer.train()
+                print(f"Loading the best model state dict with validation loss {model_selection_callback.best_validation_loss}")
+                model.load_state_dict(model_selection_callback.best_model_state_dict)
+                del model_selection_callback.best_model_state_dict
+                gc.collect()
+                torch.cuda.empty_cache()
+                tol_level = 0
+                prev_train_iter = num_iter
\ No newline at end of file
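optimize() threads each entry of additional_properties into the prompts through the start_tag / end_tag / calculate_value / infer_value fields (see OptimEntry.to_prompt in utils.py below). A minimal sketch of one property spec; the tag strings and target value are illustrative assumptions, not part of this patch:

    from rdkit.Chem import rdMolDescriptors

    additional_properties = {
        "tpsa": {
            "start_tag": "[TPSA]",   # hypothetical tag; must exist in the model's vocabulary
            "end_tag": "[/TPSA]",
            "value": None,           # filled per molecule by optimize() via calculate_value
            # computed for each newly generated molecule:
            "calculate_value": lambda entry: f"{rdMolDescriptors.CalcTPSA(entry.mol):.2f}",
            # target value written into the generation prompt (must return a string):
            "infer_value": lambda entry: "80.00",
        },
    }
    # optimize(model, tokenizer, oracle, config, additional_properties=additional_properties)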
diff --git a/chemlactica/mol_opt/oracle_estimators.py b/chemlactica/mol_opt/oracle_estimators.py
new file mode 100644
index 0000000..3d64477
--- /dev/null
+++ b/chemlactica/mol_opt/oracle_estimators.py
@@ -0,0 +1,96 @@
+from typing import List
+import time
+from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+from transformers import AutoModelForCausalLM, PreTrainedModel, AutoModel, AutoConfig, AutoTokenizer
+import torch
+import torch.nn as nn
+import numpy as np
+from chemlactica.mol_opt.utils import MoleculeEntry
+from sklearn.linear_model import Ridge
+
+
+def find_second_eos_token_indices(sequences, eos_token_id):
+    return torch.where(sequences[:, 1:] == eos_token_id)
+
+
+def init_linear_layer(layer, emb_length):
+    torch.nn.init.normal_(
+        layer.weight,
+        mean=0.0, std=1 / np.sqrt(emb_length + 1)
+    )
+    torch.nn.init.constant_(layer.bias, val=0.0)
+    return layer
+
+
+class ScalarHeadLM(PreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.lm_backbone = AutoModel.from_pretrained(
+            config._name_or_path,
+            config=config
+        )
+        self.scalar_head = nn.Linear(config.hidden_size, 1)
+        init_linear_layer(self.scalar_head, config.hidden_size)
+
+    def forward(self, **kwargs):
+        output = self.lm_backbone(**kwargs)
+        return self.scalar_head(output.last_hidden_state)
+
+
+class LinearFingerprintModel:
+
+    def __init__(self):
+        self.emb_length = 2048
+        self.linear = Ridge()
+        self.all_entries = []
+        self.is_fit = False
+
+    def __call__(self, mol_entries: List[MoleculeEntry]):
+        mol_embs = np.array([entry.fingerprint for entry in mol_entries])
+        return self.linear.predict(mol_embs)
+
+    def fit(self, mol_entries: List[MoleculeEntry]):
+        self.is_fit = True
+        start_time = time.time()
+        self.all_entries.extend(mol_entries)
+        mol_embs = np.array([entry.fingerprint for entry in self.all_entries])
+        scores = np.array([entry.score for entry in self.all_entries])
+        self.linear.fit(mol_embs, scores)
+        print(f"Fit time {time.time() - start_time:.4f}s")
+
+
+class ScalarOracleApproximator:
+
+    def __init__(self, config, tokenizer):
+        self.scalar_head_lm = ScalarHeadLM(config)
+        self.tokenizer = tokenizer
+
+    def __call__(self, mol_entries):
+        prompts = [f"[START_SMILES]{e.smiles}[END_SMILES]" for e in mol_entries]
+        data = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.scalar_head_lm.device)
+        del data["token_type_ids"]
+        outputs = self.scalar_head_lm(
+            **data
+        )
+        print(outputs)
+
+
+class SFTOracleApproximator:
+
+    def __init__(self, config, tokenizer, device):
+        self.ml = AutoModelForCausalLM.from_pretrained(
+            config._name_or_path,
+            config=config
+        ).to(device)
+        self.tokenizer = tokenizer
+
+
+if __name__ == "__main__":
+    config = AutoConfig.from_pretrained("/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/26d322857a184fcbafda5d4a/checkpoint-118784")
+    tokenizer = AutoTokenizer.from_pretrained("chemlactica/tokenizer/ChemLacticaTokenizer66", padding_side="left")
+    scalar_oracle_approx = ScalarOracleApproximator(config, tokenizer)
+
+    mol_entries = [MoleculeEntry("CCC" + i * "C") for i in range(10)]
+    scalar_oracle_approx(mol_entries)
\ No newline at end of file
diff --git a/chemlactica/mol_opt/slurm_hparam_search.py b/chemlactica/mol_opt/slurm_hparam_search.py
new file mode 100644
index 0000000..eaa775b
--- /dev/null
+++ b/chemlactica/mol_opt/slurm_hparam_search.py
@@ -0,0 +1,79 @@
+import submitit
+import subprocess
+import itertools as it
+import datetime
+import yaml
+import os
+import copy
+
+
+def create_hparam_configs(config_file_path):
+    config_tune = yaml.safe_load(open("hparams_tune.yaml"))
+    config_merged = {}
+    for key, value in config_tune["parameters"].items():
+        if isinstance(value, list):
+            config_merged[key] = value
+        else:
+            for k, v in value.items():
+                config_merged[key+'+'+k] = v
+
+    config_default = yaml.safe_load(open(config_file_path))
+    hparam_names = list(config_merged.keys())
+    all_configs = []
+    for params in it.product(*config_merged.values()):
+        # pprint(params)
+        # pprint(hparam_names)
+        config = copy.deepcopy(config_default)
+        for i, p in enumerate(params):
+            if '+' in hparam_names[i]:
+                a, b = hparam_names[i].split("+")
+                config[a][b] = p
+            else:
+                config[hparam_names[i]] = p
+        # pprint(params)
+        # pprint(config)
+        all_configs.append(config)
+        # print(config)
+    return all_configs
+
+
+if __name__ == "__main__":
+    n_runs = 3
+
+    config_file_path = "chemlactica_125m_hparams.yaml"
+    # config_file_path = "main/chemlactica/chemma_2b_hparams.yaml"
+    hparam_configs = create_hparam_configs(config_file_path)
+    # infer_config = [yaml.safe_load(open(config_file_path))]
+    model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2])
+
+    executor = submitit.AutoExecutor(folder="/auto/home/tigranfahradyan/slurm_jobs/PMO/job_%j")
+    executor.update_parameters(
+        name="chemlactica-pmo", timeout_min=n_runs * 3 * 60,
+        gpus_per_node=1, nodes=1, mem_gb=50, cpus_per_task=8,
+        slurm_array_parallelism=10
+    )
+    jobs = []
+    with executor.batch():
+        for config in hparam_configs[:1]:
+            formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d")
+            base = f"results/{formatted_date_time}"
+            os.makedirs(base, exist_ok=True)
+            v = 0
+            name = model_name + "-" + "+".join(config["strategy"])
+            while os.path.exists(os.path.join(base, f"{name}-{v}")):
+                v += 1
+            output_dir = os.path.join(base, f"{name}-{v}")
+            output_dir += "tune"
+            # output_dir = "main/chemlactica/results/2024-05-11/chemlactica-125m-rej-sample-4"
+            os.makedirs(output_dir, exist_ok=True)
+            yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w"))
+            function = submitit.helpers.CommandFunction([
+                'python3', 'hparam_search.py',
+                '--config_default', os.path.join(output_dir, "hparams.yaml"),
+                '--output_dir', output_dir,
+                '--n_runs', str(n_runs),
+            ])
+            print(' '.join(function.command))
+            # subprocess.run(function.command)
+            job = executor.submit(function)
+            jobs.append(job)
diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py
new file mode 100644
index 0000000..1332f0e
--- /dev/null
+++ b/chemlactica/mol_opt/tunning.py
@@ -0,0 +1,116 @@
+from transformers.trainer_callback import TrainerControl, TrainerState, TrainerCallback
+from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, EarlyStoppingCallback
+from torch.optim.lr_scheduler import ConstantLR
+import torch
+import math
+import time
+from collections import OrderedDict
+from chemlactica.mol_opt.utils import generate_random_number
+
+
+class CustomEarlyStopCallback(TrainerCallback):
+
+    def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
+        super().__init__()
+        self.best_valid_loss = math.inf
+        self.early_stopping_patience = early_stopping_patience
+        self.current_patience = 0
+        self.early_stopping_threshold = early_stopping_threshold
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        self.best_valid_loss = math.inf
+        self.current_patience = 0
+        return super().on_train_begin(args, state, control, **kwargs)
+
+    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+        if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
+            self.current_patience += 1
+        else:
+            self.current_patience = 0
+            self.best_valid_loss = metrics["eval_loss"]
+        print(f"Early stopping patience: {self.current_patience}/{self.early_stopping_patience}")
+        if self.current_patience >= self.early_stopping_patience:
+            control.should_training_stop = True
+        return super().on_evaluate(args, state, control, **kwargs)
+
+
+class CustomModelSelectionCallback(TrainerCallback):
+
+    def __init__(self):
+        super().__init__()
+        self.best_validation_loss: float = math.inf
+        self.best_model_state_dict: OrderedDict = OrderedDict()
+
+    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, model, **kwargs):
+        if metrics["eval_loss"] <= self.best_validation_loss:
+            self.best_validation_loss = metrics["eval_loss"]
+            print(f"Better validation loss achieved {self.best_validation_loss}, updating the state dict.")
+            for key, value in model.state_dict().items():
+                self.best_model_state_dict[key] = value.detach().clone()
+        return super().on_evaluate(args, state, control, **kwargs)
+
+
+# class CustomSFTTrainer(SFTTrainer):
+#
+    # def __init__(self, *args, patience, toll, **kwargs):
+    #     super().__init__(*args, **kwargs)
+    #     self.patience = patience
+    #     self.initial_pat = patience
+    #     self.toll = toll
+    #     self.best_loss = math.inf
+
+    # def log(self, logs) -> None:
+    #     if logs.get("loss"):
+    #         curr_loss = logs["loss"]
+    #         if curr_loss > self.best_loss - self.toll:
+    #             self.patience -= 1
+    #             print(f"loss did not improve, patience {self.patience}")
+    #         else:
+    #             print("loss improved")
+    #             self.best_loss = curr_loss
+    #             self.patience = self.initial_pat
+    #         if self.patience == 0:
+    #             print("The loss does not improve, stop training.")
+    #             self.control.should_training_stop = True
+    #     return super().log(logs)
+
+
+def get_training_arguments(config):
+    checkpoints_dir = f"{config['checkpoints_dir']}/checkpoint-{time.time():.4f}"
+    return TrainingArguments(
+        output_dir=checkpoints_dir,
+        per_device_train_batch_size=config["train_batch_size"],
+        per_device_eval_batch_size=config["train_batch_size"],
+        max_grad_norm=config["global_gradient_norm"],
+        num_train_epochs=config["num_train_epochs"],
+        evaluation_strategy="epoch",
+        # save_strategy="epoch",
+        save_strategy="no",
+        dataloader_drop_last=False,
+        dataloader_pin_memory=True,
+        dataloader_num_workers=config["dataloader_num_workers"],
+        gradient_accumulation_steps=config["gradient_accumulation_steps"],
+        logging_steps=1,
+        save_safetensors=False,
+        metric_for_best_model="loss",
+        # load_best_model_at_end=True,
+        # save_total_limit=1
+    )
+
+
+def get_optimizer_and_lr_scheduler(model, config, max_train_steps):
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=config["max_learning_rate"],
+        betas=[config["adam_beta1"], config["adam_beta2"]],
+        weight_decay=config["weight_decay"],
+    )
+    lr_scheduler = get_polynomial_decay_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=config["warmup_steps"],
+        num_training_steps=max_train_steps,
+        lr_end=config["lr_end"],
+        power=1.0,
+    )
+    return optimizer, lr_scheduler
\ No newline at end of file
diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py
new file mode 100644
index 0000000..ec6d2af
--- /dev/null
+++ b/chemlactica/mol_opt/utils.py
@@ -0,0 +1,338 @@
+from typing import List
+import datetime
+import os
+import random
+from pathlib import Path
+import numpy as np
+import torch
+from chemlactica.mol_opt.metrics import top_auc
+from rdkit import Chem, DataStructs, RDLogger
+from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors
+
+# Disable RDKit logs
+RDLogger.DisableLog("rdApp.*")
+
+
+def set_seed(seed_value):
+    random.seed(seed_value)
+    # Set seed for NumPy
+    np.random.seed(seed_value)
+    # Set seed for PyTorch
+    torch.manual_seed(seed_value)
+
+
+def get_short_name_for_ckpt_path(chpt_path: str, hash_len: int = 6):
+    ckpt_path = Path(chpt_path)
+    return (
+        ckpt_path.parent.name[:hash_len]
+        + "-"
+        + ckpt_path.name.split("-")[-1]
+    )
+
+
+def get_morgan_fingerprint(mol):
+    return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
+
+
+def get_maccs_fingerprint(mol):
+    return MACCSkeys.GenMACCSKeys(mol)
+
+
+def tanimoto_dist_func(fing1, fing2, fingerprint: str = "morgan"):
+    # Morgan and MACCS fingerprints are both RDKit bit vectors,
+    # so the same Tanimoto computation applies to either type.
+    return DataStructs.TanimotoSimilarity(fing1, fing2)
+
+
+def generate_random_number(lower, upper):
+    return lower + random.random() * (upper - lower)
+
+
+def canonicalize(smiles):
+    mol = Chem.MolFromSmiles(smiles)
+    return Chem.MolToSmiles(mol, canonical=True)
+    # return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), kekuleSmiles=True)
+
+
+class MoleculeEntry:
+    def __init__(self, smiles, score=0, **kwargs):
+        self.smiles = smiles
+        self.score = score
+        self.similar_mol_entries = []
+        if smiles:
+            self.smiles = canonicalize(smiles)
+            self.mol = Chem.MolFromSmiles(smiles)
+            self.fingerprint = get_morgan_fingerprint(self.mol)
+        self.add_props = kwargs
+
+    def __eq__(self, other):
+        return self.smiles == other.smiles
+
+    def __lt__(self, other):
+        if self.score == other.score:
+            return self.smiles < other.smiles
+        return self.score < other.score
+
+    def __str__(self):
+        return (
+            f"smiles: {self.smiles}, "
+            f"score: {round(self.score, 4) if self.score is not None else 'none'}"
+        )
+
+    def __repr__(self):
+        return str(self)
+
+    def __hash__(self):
+        return hash(self.smiles)
+
+
+class ConstraedTPSAOracle:
+    def __init__(self, max_oracle_calls: int):
+        self.max_oracle_calls = max_oracle_calls
+        self.freq_log = 100
+        self.mol_buffer = {}
+        self.max_possible_oracle_score = 1.0
+        self.takes_entry = True
+
+    def __call__(self, molecules):
+        oracle_scores = []
+        for molecule in molecules:
+            if self.mol_buffer.get(molecule.smiles):
+                oracle_scores.append(self.mol_buffer[molecule.smiles][0])
+            else:
+                try:
+                    tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
+                    tpsa_score = min(tpsa / 1000, 1)
+                    weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
+                    if weight <= 349:
+                        weight_score = 1
+                    elif weight >= 500:
+                        weight_score = 0
+                    else:
+                        weight_score = -0.00662 * weight + 3.31125
+
+                    num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
+                    if num_rings >= 2:
+                        num_rings_score = 1
+                    else:
+                        num_rings_score = 0
+                    # print(tpsa_score, weight_score, num_rings_score)
+                    oracle_score = (tpsa_score + weight_score + num_rings_score) / 3
+                except Exception as e:
+                    print(e)
+                    oracle_score = 0
+                self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
+                if len(self.mol_buffer) % self.freq_log == 0:
+                    self.log_intermediate()
+                oracle_scores.append(oracle_score)
+        return oracle_scores
+
+    def log_intermediate(self):
+        scores = [v[0] for v in self.mol_buffer.values()]
+        scores_sorted = sorted(scores, reverse=True)[:100]
+        n_calls = len(self.mol_buffer)
+
+        score_avg_top1 = np.max(scores_sorted)
+        score_avg_top10 = np.mean(scores_sorted[:10])
+        score_avg_top100 = np.mean(scores_sorted)
+
+        print(f"{n_calls}/{self.max_oracle_calls} | ",
+              f"auc_top1: {top_auc(self.mol_buffer, 1, False, self.freq_log, self.max_oracle_calls)} | ",
+              f"auc_top10: {top_auc(self.mol_buffer, 10, False, self.freq_log, self.max_oracle_calls)} | ",
+              f"auc_top100: {top_auc(self.mol_buffer, 100, False, self.freq_log, self.max_oracle_calls)}")
+
+        print(f'avg_top1: {score_avg_top1:.3f} | '
+              f'avg_top10: {score_avg_top10:.3f} | '
+              f'avg_top100: {score_avg_top100:.3f}')
+
+    def __len__(self):
+        return len(self.mol_buffer)
+
+    @property
+    def budget(self):
+        return self.max_oracle_calls
+
+    @property
+    def finish(self):
+        return len(self.mol_buffer) >= self.max_oracle_calls
+
+
+class Pool:
+    def __init__(self, size, validation_perc: float):
+        self.size = size
+        self.optim_entries: List[OptimEntry] = []
+        self.num_validation_entries = int(size * validation_perc + 1)
+
+    # def random_dump(self, num):
+    #     for _ in range(num):
+    #         rand_ind = random.randint(0, num - 1)
+    #         self.molecule_entries.pop(rand_ind)
+    #     print(f"Dump {num} random elements from pool, num pool mols {len(self)}")
+
+    def add(self, entries: List, diversity_score=1.0):
+        assert isinstance(entries, list)
+        self.optim_entries.extend(entries)
+        self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True)
+
+        # remove duplicates
+        new_optim_entries = []
+        for entry in self.optim_entries:
+            insert = True
+            for e in new_optim_entries:
+                if (
+                    entry.last_entry == e.last_entry
+                    or tanimoto_dist_func(
+                        entry.last_entry.fingerprint, e.last_entry.fingerprint
+                    )
+                    > diversity_score
+                ):
+                    insert = False
+                    break
+            if insert:
+                new_optim_entries.append(entry)
+
+        self.optim_entries = new_optim_entries[: min(len(new_optim_entries), self.size)]
+        curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries])
+
+        i = 0
+        while curr_num_validation_entries < self.num_validation_entries:
+            if self.optim_entries[i].entry_status == EntryStatus.none:
+                self.optim_entries[i].entry_status = EntryStatus.valid
+                curr_num_validation_entries += 1
+            i += 1
+
+        for j in range(i, len(self.optim_entries)):
+            if self.optim_entries[j].entry_status == EntryStatus.none:
+                self.optim_entries[j].entry_status = EntryStatus.train
+
+        curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries])
+        assert curr_num_validation_entries == min(len(self.optim_entries), self.num_validation_entries)
+
+    def get_train_valid_entries(self):
+        train_entries = []
+        valid_entries = []
+        for entry in self.optim_entries:
+            if entry.entry_status == EntryStatus.train:
+                train_entries.append(entry)
+            elif entry.entry_status == EntryStatus.valid:
+                valid_entries.append(entry)
+            else:
+                raise Exception(f"EntryStatus of an entry in pool cannot be {entry.entry_status}.")
+        assert min(len(self.optim_entries), self.num_validation_entries) == len(valid_entries)
+        return train_entries, valid_entries
+
+    def random_subset(self, subset_size):
+        # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size))
+        rand_inds = np.random.permutation(len(self.optim_entries))
+        rand_inds = rand_inds[:subset_size]
+        return [self.optim_entries[i] for i in rand_inds]
+
+    def __len__(self):
+        return len(self.optim_entries)
+
+
+def make_output_files_base(input_path, results_dir, run_name, config):
+    formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d")
+    base = os.path.join(input_path, run_name, formatted_date_time)
+    os.makedirs(base, exist_ok=True)
+    v = 0
+    strategy = "+".join(config["strategy"])
+    while os.path.exists(os.path.join(base, f"{strategy}-{v}")):
+        v += 1
+    output_dir = os.path.join(base, f"{strategy}-{v}")
+    os.makedirs(output_dir, exist_ok=True)
+    return output_dir
+
+
+def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None):
+    prompt = ""
+    for sim_mol_entry in mol_entry.similar_mol_entries:
+        if sim_range:
+            prompt += f"[SIMILAR]{sim_mol_entry.smiles} {generate_random_number(sim_range[0], sim_range[1]):.2f}[/SIMILAR]"  # noqa
+        else:
+            prompt += f"[SIMILAR]{sim_mol_entry.smiles} {tanimoto_dist_func(sim_mol_entry.fingerprint, mol_entry.fingerprint):.2f}[/SIMILAR]"  # noqa
+    return prompt
+
+
+class EntryStatus:
+    none = 0
+    train = 1
+    valid = 2
+
+
+class OptimEntry:
+    def __init__(self, last_entry, mol_entries):
+        self.last_entry: MoleculeEntry = last_entry
+        self.mol_entries: List[MoleculeEntry] = mol_entries
+        self.entry_status: EntryStatus = EntryStatus.none
+
+    def to_prompt(
+        self, is_generation: bool,
+        include_oracle_score: bool, config,
+        max_score
+    ):
+        prompt = ""
+        # prompt = config["eos_token"]
+        for mol_entry in self.mol_entries:
+            prompt += config["eos_token"]
+            prompt += create_prompt_with_similars(mol_entry=mol_entry)
+
+            for prop_name, prop_spec in mol_entry.add_props.items():
+                prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}"
+
+            if "default" in config["strategy"]:
+                pass
+            elif "rej-sample-v2" in config["strategy"]:
+                if include_oracle_score:
+                    prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]"
+            else:
+                raise Exception(f"Strategy {config['strategy']} not known.")
+            prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]"
+
+        assert self.last_entry
+        prompt += config["eos_token"]
+        if is_generation:
+            prompt_with_similars = create_prompt_with_similars(
+                self.last_entry, sim_range=config["sim_range"]
+            )
+        else:
+            prompt_with_similars = create_prompt_with_similars(self.last_entry)
+
+        prompt += prompt_with_similars
+
+        for prop_name, prop_spec in self.last_entry.add_props.items():
+            prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"]
+
+        if "default" in config["strategy"]:
+            pass
+        elif "rej-sample-v2" in config["strategy"]:
+            if is_generation:
+                desired_oracle_score = generate_random_number(
+                    max_score, config["max_possible_oracle_score"]
+                )
+                oracle_score = desired_oracle_score
+            else:
+                oracle_score = self.last_entry.score
+            if include_oracle_score:
+                prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]"
+        else:
+            raise Exception(f"Strategy {config['strategy']} not known.")
+
+        if is_generation:
+            prompt += "[START_SMILES]"
+        else:
+            prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]"
+            prompt += config["eos_token"]
+
+        return prompt
+
+    def contains_entry(self, mol_entry: MoleculeEntry):
+        for entry in self.mol_entries:
+            if mol_entry == entry:
+                return True
+            for sim_entry in entry.similar_mol_entries:
+                if mol_entry == sim_entry:
+                    return True
+
+        return False
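optimize() only touches a small oracle surface, all of which ConstraedTPSAOracle above implements: the callable scoring interface (takes_entry selects MoleculeEntry objects vs. SMILES strings), mol_buffer, max_oracle_calls, budget, finish and __len__. A minimal custom oracle honoring that contract, with a placeholder scoring rule:

    class ToyOracle:
        def __init__(self, max_oracle_calls: int):
            self.max_oracle_calls = max_oracle_calls
            self.mol_buffer = {}
            self.max_possible_oracle_score = 1.0
            self.takes_entry = True  # receive MoleculeEntry objects, not SMILES strings

        def __call__(self, mol_entries):
            scores = []
            for entry in mol_entries:
                if entry.smiles not in self.mol_buffer:
                    score = 0.5  # placeholder objective; substitute a real scoring function
                    self.mol_buffer[entry.smiles] = [score, len(self.mol_buffer) + 1]
                scores.append(self.mol_buffer[entry.smiles][0])
            return scores

        def __len__(self):
            return len(self.mol_buffer)

        @property
        def budget(self):
            return self.max_oracle_calls

        @property
        def finish(self):
            return len(self.mol_buffer) >= self.max_oracle_calls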
config["max_possible_oracle_score"] + ) + oracle_score = desired_oracle_score + else: + oracle_score = self.last_entry.score + if include_oracle_score: + prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]" + else: + raise Exception(f"Strategy {config['strategy']} not known.") + + if is_generation: + prompt += "[START_SMILES]" + else: + prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" + prompt += config["eos_token"] + + return prompt + + def contains_entry(self, mol_entry: MoleculeEntry): + for entry in self.mol_entries: + if mol_entry == entry: + return True + for sim_entry in entry.similar_mol_entries: + if mol_entry == sim_entry: + return True + + return False diff --git a/chemlactica/utils/logits_processors.py b/chemlactica/utils/logits_processors.py new file mode 100644 index 0000000..4abe0cd --- /dev/null +++ b/chemlactica/utils/logits_processors.py @@ -0,0 +1,88 @@ +from transformers import LogitsProcessor +import torch +from typing import List, Union + + +class OneOccurenceLogitsProcessor(LogitsProcessor): + def __init__(self, suppress_tokens): + self.suppress_tokens = list(suppress_tokens) + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + for token_id in self.suppress_tokens: + if token_id in input_ids: + scores[:, token_id] = -float("inf") + return scores + + +class TunableExponentialDecayLengthPenalty(LogitsProcessor): + def __init__( + self, + exponential_decay_factors: List[float], + regulation_starts: List[int], + decay_token_ids: Union[int, List[int]], + input_ids_seq_length: int, + ): + self.regulation_starts = regulation_starts + self.regulation_list = exponential_decay_factors + if isinstance(decay_token_ids, int): + decay_token_ids = [decay_token_ids] + self.decay_token_ids = decay_token_ids + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + penalties = torch.zeros_like(scores) + scores_processed = scores + for token_id, regulation_factor, regulation_start in zip( + self.decay_token_ids, self.regulation_list, self.regulation_starts + ): + if cur_len > regulation_start: + penalty_idx = cur_len - regulation_start + penalty = torch.abs(scores[:, token_id]) * ( + pow(regulation_factor, penalty_idx) - 1 + ) + penalties[:, token_id] = penalty + scores_processed = scores + penalties + return scores_processed + + +class SequentialDecayProcessor(LogitsProcessor): + def __init__( + self, + exponential_decay_factors: List[float], + decay_token_ids: Union[int, List[int]], + input_ids_seq_length: int, + ): + self.regulation_list = exponential_decay_factors + if isinstance(decay_token_ids, int): + decay_token_ids = [decay_token_ids] + + self.decay_token_ids = decay_token_ids + self.regulation_starts = 99999999999999999 * len(self.decay_token_ids) + self.regulation_starts[0] = 0 + self.regulation_list = exponential_decay_factors + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + penalties = torch.zeros_like(scores) + scores_processed = scores + for i, (token_id, regulation_factor) in enumerate( + zip(self.decay_token_ids, self.regulation_list) + ): + if token_id in input_ids: + self.regulation_starts[i] = cur_len + + reg_start = self.regulation_starts[i] + if cur_len > reg_start: + penalty_idx = cur_len - reg_start + penalty = torch.abs(scores[:, token_id]) * ( + pow(regulation_factor, penalty_idx) - 1 + ) + penalties[:, 
diff --git a/chemlactica/utils/logits_utils.py b/chemlactica/utils/logits_utils.py
new file mode 100644
index 0000000..21de004
--- /dev/null
+++ b/chemlactica/utils/logits_utils.py
@@ -0,0 +1,79 @@
+import yaml
+import os
+from typing import List, Any, Dict, Optional
+from transformers.generation import LogitsProcessor, LogitsProcessorList
+from dataclasses import dataclass, field
+import importlib
+import importlib.util
+
+
+def import_local_module(module_name, relative_path):
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    module_path = os.path.join(current_dir, relative_path)
+
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    if spec is None:
+        raise ImportError(f"Could not import module {module_name} from {module_path}")
+
+    # Create a new module based on the spec
+    module = importlib.util.module_from_spec(spec)
+
+    # Execute the module's code and populate the module
+    spec.loader.exec_module(module)
+
+    return module
+
+
+@dataclass
+class LogitsProcessorConfig:
+    class_name: str
+    is_local: bool
+    module: str
+    kwargs: Dict[str, Any]
+    path: Optional[str] = field(default=None)
+
+
+def instantiate_processors(
+    config: List[LogitsProcessorConfig],
+) -> List[LogitsProcessor]:
+    processors = []
+    for processor_config in config:
+        if processor_config.is_local:
+            module = import_local_module(processor_config.module, processor_config.path)
+        else:
+            module = importlib.import_module(processor_config.module)
+        processor_class = getattr(module, processor_config.class_name)
+        processor = processor_class(**processor_config.kwargs)
+        processors.append(processor)
+    return processors
+
+
+def load_processor_config(file_path: str) -> List[LogitsProcessorConfig]:
+    with open(file_path, "r") as file:
+        config_data = yaml.safe_load(file)
+    configs = [
+        LogitsProcessorConfig(**processor)
+        for processor in config_data["logits_processors"]
+    ]
+    return configs
+
+
+def get_logits_processors(logits_processors_config_path=None):
+    # current_dir = os.path.dirname(os.path.abspath(__file__))
+    # config_file_path = os.path.join(current_dir, "logit_configs","best_config.yaml")
+    if logits_processors_config_path:
+        logit_processors_config = load_processor_config(logits_processors_config_path)
+        logit_processors = instantiate_processors(logit_processors_config)
+        logit_processor_list = LogitsProcessorList(logit_processors)
+        return logit_processor_list
+    else:
+        return None
+
+
+if __name__ == "__main__":
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    config_file_path = os.path.join(current_dir, "logit_configs", "best_config.yaml")
+    logit_processors_config = load_processor_config(config_file_path)
+    logit_processors = instantiate_processors(logit_processors_config)
+    logit_processor_list = LogitsProcessorList(logit_processors)
+    print(logit_processor_list)
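The LogitsProcessorList returned by get_logits_processors plugs directly into Hugging Face generation via the standard logits_processor argument; a usage sketch (config path assumed):

    from chemlactica.utils.logits_utils import get_logits_processors

    processors = get_logits_processors("chemlactica/utils/logit_configs/best_config.yaml")
    output = model.generate(
        **data,
        **config["generation_config"],
        logits_processor=processors,  # applied on top of the default processors
    )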