Skip to content

Commit

Permalink
Merge pull request #38 from YerevaNN/mol_opt
Browse files Browse the repository at this point in the history
Mol opt
  • Loading branch information
tigranfah authored Aug 14, 2024
2 parents 949beb7 + 0f75cc0 commit 59d754a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Instructions coming soon...
### Molecular Optimization 🎯
Running the Optimization Algorithm requires two steps:

**Step 1.** Define the Oracle, which is responsible to evaluate the oracle score for the given molecules. Below is presented the Oracle implementation scheme.
**Step 1.** Define the Oracle, which is responsible to evaluate the oracle scores for the given molecules. Below is presented the Oracle implementation scheme.

```python
class ExampleOracle:
Expand Down Expand Up @@ -97,7 +97,7 @@ rej_sample_config:
... fine tuning hyperparameters ...
```
Calling the **optimize** function.
Call the **optimize** function.
```python
from chemlactica.mol_opt.optimization import optimize
Expand All @@ -119,7 +119,7 @@ optimize(
)
```

Refer to [example_run.py](https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/mol_opt/example_run.py) for a full working example of an optimization run. For more complex examples refer to the [ChemlacticaTestSuit]() repository [mol_opt]() and [retmol]() directories.
Refer to [example_run.py](https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/mol_opt/example_run.py) for a full working example of an optimization run. For more complex examples refer to the [ChemlacticaTestSuit]() repository [mol_opt](https://github.com/YerevaNN/ChemLacticaTestSuite/tree/master/mol_opt) and [retmol](https://github.com/YerevaNN/ChemLacticaTestSuite/tree/master/retmol) directories.

## Tests
The test for running the a small sized model with the same
Expand Down
4 changes: 2 additions & 2 deletions chemlactica/mol_opt/chemlactica_125m_hparams.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir
checkpoint_path: yerevann/chemlactica-125m
tokenizer_path: yerevann/chemlactica-125m
pool_size: 10
validation_perc: 0.2
num_mols: 0
Expand Down
20 changes: 9 additions & 11 deletions chemlactica/mol_opt/example_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, max_oracle_calls: int):
self.mol_buffer = {}

# the maximum possible oracle score or an upper bound
self.max_possible_oracle_score = 1.0
self.max_possible_oracle_score = 800

# if True the __call__ function takes list of MoleculeEntry objects
# if False (or unspecified) the __call__ function takes list of SMILES strings
Expand All @@ -39,16 +39,14 @@ def __call__(self, molecules: List[MoleculeEntry]):
else:
try:
tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
tpsa_score = min(tpsa / 1000, 1)
oracle_score = tpsa
weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
if weight <= 349:
weight_score = 1
elif weight >= 500:
weight_score = 0
else:
weight_score = -0.00662 * weight + 3.31125

oracle_score = (tpsa_score + weight_score) / 3
num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
if weight >= 350:
oracle_score = 0
if num_rings < 2:
oracle_score = 0

except Exception as e:
print(e)
oracle_score = 0
Expand Down Expand Up @@ -105,7 +103,7 @@ def parse_arguments():
for i in range(args.n_runs):
set_seed(seeds[i])
oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log")
config["log_dir"] = os.path.join(args.output_dir, f"results_chemlactica_tpsa+weight+num_rungs_{seeds[i]}.log")
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
optimize(
model, tokenizer,
Expand Down
7 changes: 5 additions & 2 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def create_optimization_entries(num_entries, pool, config):
return optim_entries


def create_molecule_entry(output_text):
def create_molecule_entry(output_text, validate_smiles):
start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
start_ind = output_text.rfind(start_smiles_tag)
end_ind = output_text.rfind(end_smiles_tag)
if start_ind == -1 or end_ind == -1:
return None
generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
if not validate_smiles(generated_smiles):
return None
if len(generated_smiles) == 0:
return None

Expand All @@ -58,7 +60,8 @@ def create_molecule_entry(output_text):
def optimize(
model, tokenizer,
oracle, config,
additional_properties={}
additional_properties={},
validate_smiles=lambda x:True
):
file = open(config["log_dir"], "w")
print("config", config)
Expand Down

0 comments on commit 59d754a

Please sign in to comment.