diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 50fbad7..ef14b0a 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -110,7 +110,7 @@ def optimize( current_optim_entries = [] with multiprocessing.Pool(processes=config["num_processes"]) as pol: for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)): - if entry: + if entry and not optim_entries[i].contains_entry(entry): current_mol_entries.append(entry) current_optim_entries.append(optim_entries[i]) @@ -130,7 +130,6 @@ def optimize( tol_level = 0 if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]: break - if oracle.finish: break @@ -152,7 +151,7 @@ def optimize( # round_entries = list(np.unique(round_entries))[::-1] # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if num_iter % 3 == 0 and num_iter > initial_num_iter: + if num_iter % 5 == 0 and num_iter > initial_num_iter: training_entries = pool.optim_entries print(f"Num of train examples {len(training_entries)}.") file.write("Training entries\n") diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 98bdfb0..1ee6e16 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -194,4 +194,10 @@ def to_prompt(self, is_generation, config): else: prompt += f"[START_SMILES]{self.last_entry.smiles}[END_SMILES]" - return prompt \ No newline at end of file + return prompt + + def contains_entry(self, mol_entry: MoleculeEntry): + for entry in self.mol_entries: + if mol_entry == entry: + return True + return False \ No newline at end of file