Skip to content

Commit

Permalink
Don't emit the oracle score tag before the first training iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
tigranfah committed May 18, 2024
1 parent 20d3fe0 commit bfa32d7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
9 changes: 6 additions & 3 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def optimize(
)
optim_entries[i].last_entry = last_entry

prompts = [optim_entry.to_prompt(is_generation=True, config=config) for optim_entry in optim_entries]
prompts = [
optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config)
for optim_entry in optim_entries
]

output_texts = []
for i in range(0, len(prompts), config["generation_batch_size"]):
Expand Down Expand Up @@ -154,7 +157,7 @@ def optimize(
# round_entries = list(np.unique(round_entries))[::-1]
# top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
# if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
if config["rej_sample_config"]["train_condition"](num_iter, tol_level, prev_train_iter):
if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter):
training_entries = pool.optim_entries
print(f"Num of train examples {len(training_entries)}.")
file.write("Training entries\n")
Expand All @@ -163,7 +166,7 @@ def optimize(

train_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, config=config)
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
for optim_entry in training_entries
]
})
Expand Down
13 changes: 8 additions & 5 deletions chemlactica/mol_opt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,23 @@ def __init__(self, last_entry, mol_entries):
self.last_entry: MoleculeEntry = last_entry
self.mol_entries: List[MoleculeEntry] = mol_entries

def to_prompt(self, is_generation, config):
def to_prompt(self, is_generation: bool, include_oracle_score: bool, config):
prompt = ""
prompt = config["eos_token"]
for mol_entry in self.mol_entries:
prompt += config["eos_token"]
# prompt += config["eos_token"]
if "default" in config["strategy"]:
prompt += create_prompt_with_similars(mol_entry=mol_entry)
elif "rej-sample-v2" in config["strategy"]:
prompt += create_prompt_with_similars(mol_entry=mol_entry)
prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]"
if include_oracle_score:
prompt += f"[ORACLE_SCORE]{mol_entry.score:.2f}[/ORACLE_SCORE]"
else:
raise Exception(f"Strategy {config['strategy']} not known.")
prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]"

assert self.last_entry
prompt += config["eos_token"]
# prompt += config["eos_token"]
if is_generation:
prompt_with_similars = create_prompt_with_similars(
self.last_entry, sim_range=config["sim_range"]
Expand All @@ -195,7 +197,8 @@ def to_prompt(self, is_generation, config):
oracle_score = desired_oracle_score
else:
oracle_score = self.last_entry.score
prompt += f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]"
if include_oracle_score:
prompt += f"[ORACLE_SCORE]{oracle_score:.2f}[/ORACLE_SCORE]"
else:
raise Exception(f"Strategy {config['strategy']} not known.")

Expand Down

0 comments on commit bfa32d7

Please sign in to comment.