Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modified fingerprint db to allow storing and retrieval of metadata. #73

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
167 changes: 167 additions & 0 deletions catkit/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,12 @@ def create_table(self):
description TEXT
)""")

self.c.execute("""CREATE TABLE IF NOT EXISTS metadata_params(
pid INTEGER PRIMARY KEY AUTOINCREMENT,
symbol CHAR(10) UNIQUE NOT NULL,
description TEXT
)""")

self.c.execute("""CREATE TABLE IF NOT EXISTS fingerprints(
entry_id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INT NOT NULL,
Expand All @@ -439,6 +445,18 @@ def create_table(self):
UNIQUE(image_id, param_id)
)""")

self.c.execute("""CREATE TABLE IF NOT EXISTS metadata(
entry_id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INT NOT NULL,
param_id INT NOT NULL,
value TEXT,
FOREIGN KEY(image_id) REFERENCES images(image_id)
ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY(param_id) REFERENCES parameters(param_id)
ON DELETE CASCADE ON UPDATE CASCADE,
UNIQUE(image_id, param_id)
)""")

def image_entry(self, d, identity=None):
"""Enters a single ase-db image into the fingerprint database.
The ase-db ID with identity must be unique. If not, it will be skipped.
Expand Down Expand Up @@ -535,6 +553,72 @@ def get_parameters(self, selection=None, display=False):

return parameter_ids

def metadata_params_entry(self, symbol=None, description=None):
"""Enters a unique metadata parameter into the database.

Parameters
----------
symbol : str
A unique symbol the entry can be referenced by. If None,
the symbol will be the ID of the parameter
as a string.
description : str
A description of the parameter.
"""
if not symbol:
self.c.execute("""SELECT MAX(pid) FROM metadata_params""")
symbol = str(int(self.c.fetchone()[0]) + 1)

# The symbol must be unique. If not, it will be skipped.
try:
self.c.execute("""INSERT INTO metadata_params (symbol, description)
VALUES(?, ?)""", (symbol, description))
except (IntegrityError):
if self.verbose:
print('Symbol already defined: {}'.format(symbol))

# Each instance needs to be commited to ensure no overwriting.
# This could potentially result in slowdown.
self.con.commit()

def get_metadata_params(self, selection=None, display=False):
"""Get an array of integer values which correspond to the
metadata parameter IDs for a set of provided symbols.

Parameters
----------
selection : list
Symbols in parameters table to be selected. If no selection
is made, return all parameters.
display : bool
Print parameter descriptions.

Returns
-------
metadata_params_ids : array (n,)
Integer values of selected parameters.
"""
if not selection:
self.c.execute("""SELECT pid, symbol, description
FROM metadata_params""")
res = self.c.fetchall()
else:
res = []
for i, s in enumerate(selection):
self.c.execute("""SELECT pid, symbol, description
FROM metadata_params WHERE symbol = '{}'""".format(s))
res += [self.c.fetchone()]

if display:
print('[ID ]: key - Description')
print('---------------------------')
for r in res:
print('[{0:^3}]: {1:<10} - {2}'.format(*r))

metadata_params_ids = np.array(res).T[0].astype(int)

return metadata_params_ids

def fingerprint_entry(self, ase_id, param_id, value):
"""Enters a fingerprint value to the database for a given ase and
parameter id.
Expand Down Expand Up @@ -616,3 +700,86 @@ def get_fingerprints(self, ase_ids=None, params=[]):
fingerprint[i] = f[0].split(',')

return fingerprint

def metadata_entry(self, ase_id, param_id, value):
"""Enters a metadata value to the database for a given ase and
parameter id.

Parameters
----------
ase_id : int
The unique id associated with an atoms object in the database.
param_id : int or str
The parameter ID or symbol associated with and entry in the
parameters table.
value : str
The value of the metadata for the atoms object.
"""
# If parameter symbol is given, get the ID
if isinstance(param_id, str):
self.c.execute("""SELECT pid FROM metadata_params
WHERE symbol = '{}'""".format(param_id))
param_id = self.c.fetchone()

if param_id:
param_id = param_id[0]
else:
raise (KeyError, 'metadata symbol not found')

self.c.execute("""SELECT iid FROM images
WHERE ase_id = {}""".format(ase_id))
image_id = self.c.fetchone()[0]

self.c.execute("""INSERT INTO metadata (image_id, param_id, value)
VALUES(?, ?, ?)""", (int(image_id), int(param_id), value))

def get_metadata(self, ase_ids=None, params=[]):
"""Get the array of values associated with the provided metadata
parameters for each ase_id.

Parameters
----------
ase_id : list
The ase-id associated with an atoms object in the database.
params : list
List of symbols or int in metadata parameters table to be selected.

Returns
-------
metadata : array (n,)
An array of values associated with the given metadata parameters
for each ase_id.
"""
if isinstance(params, np.ndarray):
params = params.tolist()

if not params or isinstance(params[0], str):
params = self.get_metadata_params(selection=params)
psel = ','.join(params.astype(str))
elif isinstance(params[0], int):
psel = ','.join(np.array(params).astype(str))

if ase_ids is None:
cmd = """SELECT GROUP_CONCAT(IFNULL(value, 'nan')) FROM
metadata JOIN images on metadata.image_id = images.iid
WHERE param_id IN ({})
GROUP BY ase_id
ORDER BY images.iid""".format(psel)

else:
asel = ','.join(np.array(ase_ids).astype(str))

cmd = """SELECT GROUP_CONCAT(IFNULL(value, 'nan')) FROM
metadata JOIN images on metadata.image_id = images.iid
WHERE param_id IN ({}) AND ase_id IN ({})
GROUP BY ase_id""".format(psel, asel)

self.c.execute(cmd)
fetch = self.c.fetchall()

metadata = []
for i, f in enumerate(fetch):
metadata += [f[0].split(',')]

return metadata

5 changes: 4 additions & 1 deletion catkit/pawprint/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def get_fp(self, parameters, operation_list):
connectivity,
**kwargs)
fingerprints[i] += [fingerprint]
fingerprints = np.block(fingerprints)
try:
fingerprints = np.block(fingerprints)
except ValueError:
fingerprints = 'I\'m outta here.'

return fingerprints

Expand Down
52 changes: 47 additions & 5 deletions catkit/pawprint/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,63 @@ def local_ads_metal_fp(
connectivity=None,
fuse=False):
"""Sum of the differences in properties of the atoms in the
metal-adsorbate interface
metal-adsorbate interface

Parameters
----------
atoms : ase Atoms or catkit gratoms object.
atoms_parameters : ndarray(n, )
a list of chemical properties to construct the fingerprints.
connectivity : ndarray (n,)
Connectivity of the adsorption sites
weigthed : boolean
fingerprints are weightd by the stoichiometric ratio of
the bimetals.
"""
bond_index = np.where(atoms.get_tags() == -1)[0]
fp = np.empty([len(atoms_parameters), len(bond_index)])

for i, bi in enumerate(bond_index):
bonded_ap = atoms_parameters[:, np.where(connectivity[bi] == 1)[0]]
fp[:, i] = np.mean(bonded_ap -
atoms_parameters[:, bi].reshape(-1, 1), axis=1)

atoms_parameters[:, bi].reshape(-1, 1), axis=1)
if not fuse:
return fp.reshape(-1)
else:
return fp.sum(axis=1)
return fp.sum(axis=1)

def bimetal_fp(
atoms=None,
atoms_parameters=None,
connectivity=None):
"""The differences in properties of the atoms in the
metal-adsorbate interface

Parameters
----------
atoms : ase Atoms or catkit gratoms object.
atoms_parameters : ndarray(n, )
a list of chemical properties to construct the fingerprints.
"""
fp = np.zeros([len(atoms_parameters)])
metals = set(atoms.get_chemical_symbols())
uap = []
for ap in atoms_parameters:
if not np.isnan(ap).any():
uap += [np.unique(np.round(ap, 3)).tolist()]
else:
uap += [[np.nan, np.nan]]
for i, ap in enumerate(uap):
if len(ap) == 1:
uap[i] = [ap[0], ap[0]]

if len(metals) == 1:
fp = fp.reshape(-1)
elif len(metals) == 2:
fp = np.diff(uap).reshape(-1)
else:
raise NotImplementedError("""This operation is restricted to single
and binary metal system only.""")
return fp

def derived_fp(
atoms=None,
Expand Down
Binary file added catkit/pawprint/tests/.test_generator.py.swp
Binary file not shown.
8 changes: 5 additions & 3 deletions catkit/pawprint/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ def test_nonlocal_fingerprinting(self):

operations = [
'periodic_convolution',
['periodic_convolution', {'d': 1}]
['periodic_convolution', {'d': 1}],
'bimetal_fp'
]

fp = Fingerprinter(images)
fingerprints = fp.get_fp(parameters, operations)

truth = np.array([
[12432.0, 7.562800000000001, 320.44, 136896.0, 90.7488, 3844.8],
[2028.0, 24.53879999999999, 1200.0, 20280.0, 245.388, 12000.0]])
[12432.0, 7.562800000000001, 320.44, 136896.0, 90.7488, 3844.8,
3.20000e+01, 2.00000e-02, 2.00000e-01], [2028.0, 24.5387999999999,
1200.0, 20280.0, 245.388, 12000.0, 0.00, 0.00, 0.00]])

np.testing.assert_allclose(fingerprints, truth)

Expand Down