Skip to content

Commit

Permalink
TMP
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Aug 11, 2024
1 parent ca05315 commit c9f3a7a
Show file tree
Hide file tree
Showing 11 changed files with 870 additions and 1,013 deletions.
1 change: 0 additions & 1 deletion arc/job/adapters/ts/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def execute_incore(self):
charge=rxn.charge,
multiplicity=rxn.multiplicity,
)
rxn.arc_species_from_rmg_reaction()
reactants, products = rxn.get_reactants_and_products(arc=True, return_copies=True)
reactant_mol_combinations = list(itertools.product(*list(reactant.mol_list for reactant in reactants)))
product_mol_combinations = list(itertools.product(*list(product.mol_list for product in products)))
Expand Down
11 changes: 3 additions & 8 deletions arc/job/adapters/ts/kinbot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
"""

import os
import pytest
import shutil
import unittest

from rmgpy.reaction import Reaction
from rmgpy.species import Species

from arc.common import ARC_PATH
from arc.job.adapters.ts.kinbot_ts import KinBotAdapter, HAS_KINBOT
from arc.reaction import ARCReaction
from arc.species import ARCSpecies


class TestKinBotAdapter(unittest.TestCase):
Expand All @@ -34,10 +31,8 @@ def setUpClass(cls):
def test_intra_h_migration(self):
"""Test KinBot for intra H migration reactions"""
if HAS_KINBOT:
rxn1 = ARCReaction(reactants=['CC[O]'], products=['[CH2]CO'])
rxn1.rmg_reaction = Reaction(reactants=[Species().from_smiles('CC[O]')],
products=[Species().from_smiles('[CH2]CO')])
rxn1.arc_species_from_rmg_reaction()
rxn1 = ARCReaction(r_species=[ARCSpecies(label='R1', smiles='CC[O]')],
p_species=[ARCSpecies(label='P1', smiles='[CH2]CO')])
self.assertEqual(rxn1.family, 'intra_H_migration')
kinbot1 = KinBotAdapter(job_type='tsg',
reactions=[rxn1],
Expand Down
6 changes: 4 additions & 2 deletions arc/mapping/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ def map_rxn(rxn: 'ARCReaction',
"""
r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction=rxn)

assign_labels_to_products(rxn, p_label_dict)
assign_labels_to_products(rxn)

reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species)
label_species_atoms(reactants), label_species_atoms(products)

r_bdes, p_bdes = find_all_bdes(rxn, r_label_dict, True), find_all_bdes(rxn, p_label_dict, False)
r_bdes, p_bdes = find_all_bdes(rxn, True), find_all_bdes(rxn, False)

r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes)
p_cuts = cut_species_based_on_atom_indices(products, p_bdes)
Expand All @@ -242,6 +242,8 @@ def map_rxn(rxn: 'ARCReaction',
r_cuts, p_cuts = update_xyz(r_cuts), update_xyz(p_cuts)

pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts)
for p_tup in pairs_of_reactant_and_products:
print(f'\npairs_of_reactant_and_products: {[s.mol for s in p_tup]}\n\n\n')
if len(p_cuts):
logger.error(f"Could not find isomorphism for scissored species: {[cut.mol.smiles for cut in p_cuts]}")
return None
Expand Down
852 changes: 426 additions & 426 deletions arc/mapping/driver_test.py

Large diffs are not rendered by default.

47 changes: 31 additions & 16 deletions arc/mapping/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ def make_bond_changes(rxn: 'ARCReaction',
r_cuts: the cut products
r_label_dict: the dictionary object the find the relevant location.
"""
for action in rxn.family.forward_recipe.actions:
for action in ReactionFamily(label=rxn.family).actions:
if action[0].lower() == "CHANGE_BOND".lower():
indicies = r_label_dict[action[1]],r_label_dict[action[3]]
for r_cut in r_cuts:
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def make_bond_changes(rxn: 'ARCReaction',
r_cut.mol.update()


def get_p_label_dict(rxn: 'ARCReaction'):
def get_label_dict(rxn: 'ARCReaction') -> Optional[Dict[str, int]]:
"""
A function that returns the labels of the products.
Expand All @@ -1151,19 +1151,28 @@ def assign_labels_to_products(rxn: 'ARCReaction'):
Add the indices to the reactants and products.
Args:
rxn: ARCReaction object to be mapped
Consider changing in rmgpy.
rxn ('ARCReaction'): The reaction to be mapped.
Returns:
Adding labels to the atoms of the reactants and products, to be identified later.
"""
p_label_dict = get_p_label_dict(rxn)
print(p_label_dict)
label_dict = get_label_dict(rxn)
print(f'\n\nlabel_dict: {label_dict}\n\n')
atom_index = 0
for r in rxn.r_species:
for atom in r.mol.atoms:
if atom_index in label_dict.values():
atom.label = key_by_val(label_dict, atom_index)
atom_index += 1




atom_index = 0
for product in rxn.p_species:
for atom in product.mol.atoms:
if atom_index in p_label_dict.values() and (atom.label is str or atom.label is None):
atom.label = key_by_val(p_label_dict, atom_index)
if atom_index in label_dict.values() and (atom.label is str or atom.label is None):
atom.label = key_by_val(label_dict, atom_index)
atom_index += 1


Expand Down Expand Up @@ -1206,7 +1215,7 @@ def r_cut_p_cut_isomorphic(reactant, product):

def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies],
p_cuts: List[ARCSpecies]
)-> List[Tuple[ARCSpecies,ARCSpecies]]:
) -> List[Tuple[ARCSpecies,ARCSpecies]]:
"""
A function for matching reactants and products in scissored products.
Expand All @@ -1222,7 +1231,7 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies],
for product_cut in p_cuts:
if r_cut_p_cut_isomorphic(reactant_cut, product_cut):
pairs.append((reactant_cut, product_cut))
p_cuts.remove(product_cut) # Just in case there are two of the same species in the list, matching them by order.
p_cuts.remove(product_cut) # Just in case there are two of the same species in the list, matching them by order.
break
return pairs

Expand Down Expand Up @@ -1361,18 +1370,24 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci
return copies


def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> List[Tuple[int, int]]:
def find_all_bdes(rxn: "ARCReaction",
is_reactants: bool,
) -> List[Tuple[int, int]]:
"""
A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices.
Args:
rxn (ARCReaction): The reaction in question.
label_dict (dict): A dictionary of the atom indices to the atom labels.
is_reactants (bool): Whether the species list represents reactants or products.
Returns:
List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond. Note that these represent the atom indicies to be cut, and not final BDEs.
List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond.
Note that these represent the atom indices to be cut, and not final BDEs.
"""
label_dict = get_label_dict(rxn)
bdes = list()
for action in ReactionFamily(rxn.family).actions:
if action[0].lower() == ("break_bond" if is_reactants else "form_bond"):
bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
if rxn.family is not None:
for action in ReactionFamily(rxn.family).actions:
if action[0].lower() == ("break_bond" if is_reactants else "form_bond"):
bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
return bdes
Loading

0 comments on commit c9f3a7a

Please sign in to comment.