Skip to content

Commit

Permalink
TMP
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Aug 12, 2024
1 parent c9f3a7a commit 9a69670
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 48 deletions.
2 changes: 1 addition & 1 deletion arc/mapping/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def map_rxn(rxn: 'ARCReaction',
"""
r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction=rxn)

assign_labels_to_products(rxn)
assign_labels_to_products(rxn=rxn, products=rxn.get_family_products())

reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species)
label_species_atoms(reactants), label_species_atoms(products)
Expand Down
69 changes: 44 additions & 25 deletions arc/mapping/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def pair_reaction_products(reaction: 'ARCReaction',
products (List[ARCSpecies]): Species that correspond to the ARCReaction products that require pairing.
Returns:
Dict[int, int]: Keys are specie indices in the ARC reaction, values are respective indices in the product list.
Dict[int, int]: Keys are species indices in the ARC reaction, values are respective indices in the product list.
"""
if reaction.is_isomerization():
return {0: 0}
Expand Down Expand Up @@ -225,7 +225,7 @@ def pair_reaction_products(reaction: 'ARCReaction',
# Returns:
# Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]:
# The first tuple entry refers to reactants, the second to products.
# Keys are specie indices in the ARC reaction,
# Keys are specied indices in the ARC reaction,
# values are respective indices in the RMG reaction.
# If ``concatenate`` is ``True``, values are lists of integers. Otherwise, values are integers.
# """
Expand Down Expand Up @@ -1146,33 +1146,34 @@ def get_label_dict(rxn: 'ARCReaction') -> Optional[Dict[str, int]]:
return None


def assign_labels_to_products(rxn: 'ARCReaction'):
def assign_labels_to_products(rxn: 'ARCReaction',
products: List[Molecule],
):
"""
Add the indices to the reactants and products.
Args:
rxn ('ARCReaction'): The reaction to be mapped.
products (List[Molecule]): The products generated from the RMG family with the same atom order as the reactants.
Returns:
Adding labels to the atoms of the reactants and products, to be identified later.
"""
label_dict = get_label_dict(rxn)
print(f'\n\nlabel_dict: {label_dict}\n\n')
atom_index = 0
for r in rxn.r_species:
for atom in r.mol.atoms:
if atom_index in label_dict.values():
atom.label = key_by_val(label_dict, atom_index)
atom_index += 1




product_pairs = pair_reaction_products(reaction=rxn, products=products)
atom_index = 0
for product in rxn.p_species:
for atom in product.mol.atoms:
if atom_index in label_dict.values() and (atom.label is str or atom.label is None):
atom.label = key_by_val(label_dict, atom_index)
for product_index in range(len(products)):
rxn_product, fam_product = rxn.p_species[product_index], products[product_pairs[product_index]]
atom_map = map_two_species(spc_1=rxn_product, spc_2=fam_product, map_type='list')
for i, atom in enumerate(fam_product.atoms):
if atom_index in label_dict.values():
rxn_product.mol.atoms[atom_map[i]].label = key_by_val(label_dict, atom_index)
atom_index += 1


Expand All @@ -1185,7 +1186,7 @@ def update_xyz(spcs: List[ARCSpecies]) -> List[ARCSpecies]:
spcs: the scission products that needs to be updated
Returns:
new: A newely generated copies of the ARCSpecies, with updated xyz
list: A newly generated copies of the ARCSpecies, with updated xyz.
"""
new = list()
for spc in spcs:
Expand Down Expand Up @@ -1224,7 +1225,7 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies],
p_cuts: A list of the scissored species in the reactants
Returns:
a list of paired reactant and products, to be sent to map_two_species.
list: Paired reactant and products, to be sent to map_two_species.
"""
pairs = []
for reactant_cut in r_cuts:
Expand All @@ -1238,19 +1239,17 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies],

def map_pairs(pairs):
"""
A function that maps the mached species together
A function that maps the matched species together
Args:
pairs: A list of the pairs of reactants and species
Returns:
A list of the mapped species
"""

maps = list()
for pair in pairs:
maps.append(map_two_species(pair[0], pair[1]))

return maps


Expand All @@ -1261,23 +1260,23 @@ def label_species_atoms(spcs):
Args:
spcs: ARCSpecies object to be labeled.
"""
index=0
index = 0
for spc in spcs:
for atom in spc.mol.atoms:
atom.label = str(index)
index+=1
index += 1


def glue_maps(maps, pairs_of_reactant_and_products):
"""
a function that joins together the maps from the parts of the reaction.
Args:
rxn: ARCReaction that requires atom mapping
maps: The list of all maps of the isomorphic cuts.
pairs_of_reactant_and_products: The pairs of the reactants and products.
Returns:
an Atom Map of the compleate reaction.
list: An Atom Map of the complete reaction.
"""
am_dict = dict()
for _map, pair in zip(maps, pairs_of_reactant_and_products):
Expand Down Expand Up @@ -1319,18 +1318,21 @@ def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int
return False


def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tuple[int, int]]) -> Optional[List["ARCSpecies"]]:
def cut_species_based_on_atom_indices(species: List["ARCSpecies"],
bdes: List[Tuple[int, int]],
) -> Optional[List["ARCSpecies"]]:
"""
A function for scissoring species based on their atom indices.
Args:
species (List[ARCSpecies]): The species list that requires scission.
bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored. The atoms are described using the atom labels, and not the actuall atom positions.
Returns:
Optional[List["ARCSpecies"]]: The species list input after the scission.
"""
if not bdes:
return species

for bde in bdes:
for index, spc in enumerate(species):
if determine_bdes_on_spc_based_on_atom_labels(spc, bde):
Expand All @@ -1351,7 +1353,6 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tu
except SpeciesError:
return None
break

return species


Expand All @@ -1372,22 +1373,40 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci

def find_all_bdes(rxn: "ARCReaction",
is_reactants: bool,
products: Optional[List["Molecule"]] = None,
) -> List[Tuple[int, int]]:
"""
A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices.
Args:
rxn (ARCReaction): The reaction in question.
is_reactants (bool): Whether the species list represents reactants or products.
products (List[Molecule], optional): The products generated from the RMG family with the same atom order
as the reactants. If given, the BDE values will be mapped from them
to the reaction products.
Returns:
List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond.
Note that these represent the atom indices to be cut, and not final BDEs.
"""
label_dict = get_label_dict(rxn)
if not is_reactants:
product_pairs = pair_reaction_products(reaction=rxn, products=products) if products is not None else None
print(label_dict)
bdes = list()
if rxn.family is not None:
for action in ReactionFamily(rxn.family).actions:
if action[0].lower() == ("break_bond" if is_reactants else "form_bond"):
print(action)
if (action[0].lower() == "break_bond" and is_reactants
or action[0].lower() == "form_bond" and not is_reactants):
print(f'appending {action[1]} and {action[3]}: {(label_dict[action[1]] + 1, label_dict[action[3]] + 1)}')
bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
return bdes


rxn_product, fam_product = rxn.p_species[product_index], products[product_pairs[product_index]]
atom_map = map_two_species(spc_1=rxn_product, spc_2=fam_product, map_type='list')
for i, atom in enumerate(fam_product.atoms):
if atom_index in label_dict.values():
rxn_product.mol.atoms[atom_map[i]].label = key_by_val(label_dict, atom_index)
atom_index += 1
37 changes: 17 additions & 20 deletions arc/mapping/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from random import shuffle
import itertools

from arc.common import _check_r_n_p_symbols_between_rmg_and_arc_rxns
from arc.mapping.engine import *
from arc.reaction import ARCReaction

Expand Down Expand Up @@ -512,16 +511,11 @@ def test_pair_reaction_products(self):
def test_assign_labels_to_products(self):
"""Test assigning labels to products based on the atom map of the reaction"""
rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2])
assign_labels_to_products(rxn_1_test)
print([atom.label for atom in rxn_1_test.p_species[0].mol.atoms])
index = 0
for product in rxn_1_test.p_species:
print(product.label, index)
for atom in product.mol.atoms:
if not isinstance(atom.label, str) or atom.label != "":
print(atom.label, index)
self.assertEqual(self.p_label_dict_rxn_1[atom.label], index)
index += 1
assign_labels_to_products(rxn_1_test, rxn_1_test.get_family_products())
self.assertEqual([atom.label for atom in rxn_1_test.r_species[0].mol.atoms], ['*3', '', ''])
self.assertEqual([atom.label for atom in rxn_1_test.r_species[1].mol.atoms], ['*1', '', '*2'])
self.assertEqual([atom.label for atom in rxn_1_test.p_species[0].mol.atoms], ['*3', '', '*1', ''])
self.assertEqual([atom.label for atom in rxn_1_test.p_species[1].mol.atoms], ['', '*2'])

def test_inc_vals(self):
"""Test creating an atom map via map_two_species() and incrementing all values"""
Expand Down Expand Up @@ -550,19 +544,22 @@ def test_label_species_atoms(self):
def test_cut_species_based_on_atom_indices(self):
"""test the cut_species_for_mapping function"""
rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2],
rmg_family_set=['F_Abstraction'])
reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species)
rmg_family_set=['H_Abstraction'])
reactants = copy_species_list_for_mapping(rxn_1_test.r_species)
products = copy_species_list_for_mapping(rxn_1_test.p_species)
label_species_atoms(reactants), label_species_atoms(products)
r_bdes, p_bdes = find_all_bdes(rxn_1_test, True), find_all_bdes(rxn_1_test, False)
print(r_bdes, p_bdes)
r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes)
p_cuts = cut_species_based_on_atom_indices(products, p_bdes)

self.assertIn("C[CH]C", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn("[F]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn("[CH3]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn("C[CH]C", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])
self.assertIn("[F]", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])
self.assertIn("[CH3]", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])
print([a.mol for a in r_cuts], [a.mol for a in p_cuts])

self.assertIn('[C]#CF', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn('[C]#N', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn('[H]', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts])
self.assertIn('[C]#CF', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])
self.assertIn('[C]#N', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])
self.assertIn('[H]', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts])

spc = ARCSpecies(label="test", smiles="CNC", bdes=[(1, 2), (2, 3)])
for i, a in enumerate(spc.mol.atoms):
Expand Down
22 changes: 20 additions & 2 deletions arc/reaction/reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
A module for representing a reaction.
"""

from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from arkane.common import get_element_mass
from rmgpy.reaction import Reaction
from rmgpy.species import Species

from arc.common import get_logger
Expand All @@ -20,6 +19,9 @@
from arc.mapping.driver import map_reaction
from arc.species.species import ARCSpecies, check_atom_balance, check_label

if TYPE_CHECKING:
from rmgpy.molecule import Molecule


logger = get_logger()

Expand Down Expand Up @@ -533,6 +535,22 @@ def determine_family(self,
return family, family_own_reverse
return None, None

def get_family_products(self) -> Optional[List['Molecule']]:
"""
Determine the RMG reaction family.
Populates the .family, and .family_own_reverse attributes.
Returns:
Optional[List[Molecule]]: The products of the reaction with the same atom order as the reactants,
generated by the family. Currently only returning the first product list.
"""
product_dicts = get_reaction_family_products(rxn=self,
rmg_family_set=[self.family],
)
if len(product_dicts):
return product_dicts[0]['products']
return None

def check_attributes(self):
"""Check that the Reaction object is defined correctly"""
self.set_label_reactants_products()
Expand Down

0 comments on commit 9a69670

Please sign in to comment.