Skip to content

Commit

Permalink
Add validating function for SMILES
Browse files Browse the repository at this point in the history
  • Loading branch information
tigranfah committed Aug 14, 2024
1 parent f758645 commit 0f75cc0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions chemlactica/mol_opt/chemlactica_125m_hparams.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir
checkpoint_path: yerevann/chemlactica-125m
tokenizer_path: yerevann/chemlactica-125m
pool_size: 10
validation_perc: 0.2
num_mols: 0
Expand Down
20 changes: 9 additions & 11 deletions chemlactica/mol_opt/example_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, max_oracle_calls: int):
self.mol_buffer = {}

# the maximum possible oracle score or an upper bound
self.max_possible_oracle_score = 1.0
self.max_possible_oracle_score = 800

# if True, the __call__ function takes a list of MoleculeEntry objects
# if False (or unspecified), the __call__ function takes a list of SMILES strings
Expand All @@ -39,16 +39,14 @@ def __call__(self, molecules: List[MoleculeEntry]):
else:
try:
tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
tpsa_score = min(tpsa / 1000, 1)
oracle_score = tpsa
weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
if weight <= 349:
weight_score = 1
elif weight >= 500:
weight_score = 0
else:
weight_score = -0.00662 * weight + 3.31125

oracle_score = (tpsa_score + weight_score) / 3
num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
if weight >= 350:
oracle_score = 0
if num_rings < 2:
oracle_score = 0

except Exception as e:
print(e)
oracle_score = 0
Expand Down Expand Up @@ -105,7 +103,7 @@ def parse_arguments():
for i in range(args.n_runs):
set_seed(seeds[i])
oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log")
config["log_dir"] = os.path.join(args.output_dir, f"results_chemlactica_tpsa+weight+num_rungs_{seeds[i]}.log")
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
optimize(
model, tokenizer,
Expand Down
7 changes: 5 additions & 2 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def create_optimization_entries(num_entries, pool, config):
return optim_entries


def create_molecule_entry(output_text):
def create_molecule_entry(output_text, validate_smiles):
start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
start_ind = output_text.rfind(start_smiles_tag)
end_ind = output_text.rfind(end_smiles_tag)
if start_ind == -1 or end_ind == -1:
return None
generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
if not validate_smiles(generated_smiles):
return None
if len(generated_smiles) == 0:
return None

Expand All @@ -58,7 +60,8 @@ def create_molecule_entry(output_text):
def optimize(
model, tokenizer,
oracle, config,
additional_properties={}
additional_properties={},
validate_smiles=lambda x:True
):
file = open(config["log_dir"], "w")
print("config", config)
Expand Down

0 comments on commit 0f75cc0

Please sign in to comment.