diff --git a/catkit/db.py b/catkit/db.py index 84f04d0c..5870e897 100644 --- a/catkit/db.py +++ b/catkit/db.py @@ -427,6 +427,12 @@ def create_table(self): description TEXT )""") + self.c.execute("""CREATE TABLE IF NOT EXISTS metadata_params( + pid INTEGER PRIMARY KEY AUTOINCREMENT, + symbol CHAR(10) UNIQUE NOT NULL, + description TEXT + )""") + self.c.execute("""CREATE TABLE IF NOT EXISTS fingerprints( entry_id INTEGER PRIMARY KEY AUTOINCREMENT, image_id INT NOT NULL, @@ -439,6 +445,18 @@ def create_table(self): UNIQUE(image_id, param_id) )""") + self.c.execute("""CREATE TABLE IF NOT EXISTS metadata( + entry_id INTEGER PRIMARY KEY AUTOINCREMENT, + image_id INT NOT NULL, + param_id INT NOT NULL, + value TEXT, + FOREIGN KEY(image_id) REFERENCES images(image_id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY(param_id) REFERENCES parameters(param_id) + ON DELETE CASCADE ON UPDATE CASCADE, + UNIQUE(image_id, param_id) + )""") + def image_entry(self, d, identity=None): """Enters a single ase-db image into the fingerprint database. The ase-db ID with identity must be unique. If not, it will be skipped. @@ -535,6 +553,72 @@ def get_parameters(self, selection=None, display=False): return parameter_ids + def metadata_params_entry(self, symbol=None, description=None): + """Enters a unique metadata parameter into the database. + + Parameters + ---------- + symbol : str + A unique symbol the entry can be referenced by. If None, + the symbol will be the ID of the parameter + as a string. + description : str + A description of the parameter. + """ + if not symbol: + self.c.execute("""SELECT MAX(pid) FROM metadata_params""") + symbol = str(int(self.c.fetchone()[0]) + 1) + + # The symbol must be unique. If not, it will be skipped. + try: + self.c.execute("""INSERT INTO metadata_params (symbol, description) + VALUES(?, ?)""", (symbol, description)) + except (IntegrityError): + if self.verbose: + print('Symbol already defined: {}'.format(symbol)) + + # Each instance needs to be commited to ensure no overwriting. + # This could potentially result in slowdown. + self.con.commit() + + def get_metadata_params(self, selection=None, display=False): + """Get an array of integer values which correspond to the + metadata parameter IDs for a set of provided symbols. + + Parameters + ---------- + selection : list + Symbols in parameters table to be selected. If no selection + is made, return all parameters. + display : bool + Print parameter descriptions. + + Returns + ------- + metadata_params_ids : array (n,) + Integer values of selected parameters. + """ + if not selection: + self.c.execute("""SELECT pid, symbol, description + FROM metadata_params""") + res = self.c.fetchall() + else: + res = [] + for i, s in enumerate(selection): + self.c.execute("""SELECT pid, symbol, description + FROM metadata_params WHERE symbol = '{}'""".format(s)) + res += [self.c.fetchone()] + + if display: + print('[ID ]: key - Description') + print('---------------------------') + for r in res: + print('[{0:^3}]: {1:<10} - {2}'.format(*r)) + + metadata_params_ids = np.array(res).T[0].astype(int) + + return metadata_params_ids + def fingerprint_entry(self, ase_id, param_id, value): """Enters a fingerprint value to the database for a given ase and parameter id. @@ -616,3 +700,86 @@ def get_fingerprints(self, ase_ids=None, params=[]): fingerprint[i] = f[0].split(',') return fingerprint + + def metadata_entry(self, ase_id, param_id, value): + """Enters a metadata value to the database for a given ase and + parameter id. + + Parameters + ---------- + ase_id : int + The unique id associated with an atoms object in the database. + param_id : int or str + The parameter ID or symbol associated with and entry in the + parameters table. + value : str + The value of the metadata for the atoms object. + """ + # If parameter symbol is given, get the ID + if isinstance(param_id, str): + self.c.execute("""SELECT pid FROM metadata_params + WHERE symbol = '{}'""".format(param_id)) + param_id = self.c.fetchone() + + if param_id: + param_id = param_id[0] + else: + raise (KeyError, 'metadata symbol not found') + + self.c.execute("""SELECT iid FROM images + WHERE ase_id = {}""".format(ase_id)) + image_id = self.c.fetchone()[0] + + self.c.execute("""INSERT INTO metadata (image_id, param_id, value) + VALUES(?, ?, ?)""", (int(image_id), int(param_id), value)) + + def get_metadata(self, ase_ids=None, params=[]): + """Get the array of values associated with the provided metadata + parameters for each ase_id. + + Parameters + ---------- + ase_id : list + The ase-id associated with an atoms object in the database. + params : list + List of symbols or int in metadata parameters table to be selected. + + Returns + ------- + metadata : array (n,) + An array of values associated with the given metadata parameters + for each ase_id. + """ + if isinstance(params, np.ndarray): + params = params.tolist() + + if not params or isinstance(params[0], str): + params = self.get_metadata_params(selection=params) + psel = ','.join(params.astype(str)) + elif isinstance(params[0], int): + psel = ','.join(np.array(params).astype(str)) + + if ase_ids is None: + cmd = """SELECT GROUP_CONCAT(IFNULL(value, 'nan')) FROM + metadata JOIN images on metadata.image_id = images.iid + WHERE param_id IN ({}) + GROUP BY ase_id + ORDER BY images.iid""".format(psel) + + else: + asel = ','.join(np.array(ase_ids).astype(str)) + + cmd = """SELECT GROUP_CONCAT(IFNULL(value, 'nan')) FROM + metadata JOIN images on metadata.image_id = images.iid + WHERE param_id IN ({}) AND ase_id IN ({}) + GROUP BY ase_id""".format(psel, asel) + + self.c.execute(cmd) + fetch = self.c.fetchall() + + metadata = [] + for i, f in enumerate(fetch): + metadata += [f[0].split(',')] + + return metadata + diff --git a/catkit/pawprint/generator.py b/catkit/pawprint/generator.py index afa8dcdf..c14e04e7 100644 --- a/catkit/pawprint/generator.py +++ b/catkit/pawprint/generator.py @@ -125,7 +125,10 @@ def get_fp(self, parameters, operation_list): connectivity, **kwargs) fingerprints[i] += [fingerprint] - fingerprints = np.block(fingerprints) + try: + fingerprints = np.block(fingerprints) + except ValueError: + fingerprints = 'I\'m outta here.' return fingerprints diff --git a/catkit/pawprint/operations.py b/catkit/pawprint/operations.py index 43cb3b16..54e7ccfc 100644 --- a/catkit/pawprint/operations.py +++ b/catkit/pawprint/operations.py @@ -86,21 +86,63 @@ def local_ads_metal_fp( connectivity=None, fuse=False): """Sum of the differences in properties of the atoms in the - metal-adsorbate interface + metal-adsorbate interface + + Parameters + ---------- + atoms : ase Atoms or catkit gratoms object. + atoms_parameters : ndarray(n, ) + a list of chemical properties to construct the fingerprints. + connectivity : ndarray (n,) + Connectivity of the adsorption sites + weigthed : boolean + fingerprints are weightd by the stoichiometric ratio of + the bimetals. """ bond_index = np.where(atoms.get_tags() == -1)[0] fp = np.empty([len(atoms_parameters), len(bond_index)]) - for i, bi in enumerate(bond_index): bonded_ap = atoms_parameters[:, np.where(connectivity[bi] == 1)[0]] fp[:, i] = np.mean(bonded_ap - - atoms_parameters[:, bi].reshape(-1, 1), axis=1) - + atoms_parameters[:, bi].reshape(-1, 1), axis=1) if not fuse: return fp.reshape(-1) else: - return fp.sum(axis=1) + return fp.sum(axis=1) +def bimetal_fp( + atoms=None, + atoms_parameters=None, + connectivity=None): + """The differences in properties of the atoms in the + metal-adsorbate interface + + Parameters + ---------- + atoms : ase Atoms or catkit gratoms object. + atoms_parameters : ndarray(n, ) + a list of chemical properties to construct the fingerprints. + """ + fp = np.zeros([len(atoms_parameters)]) + metals = set(atoms.get_chemical_symbols()) + uap = [] + for ap in atoms_parameters: + if not np.isnan(ap).any(): + uap += [np.unique(np.round(ap, 3)).tolist()] + else: + uap += [[np.nan, np.nan]] + for i, ap in enumerate(uap): + if len(ap) == 1: + uap[i] = [ap[0], ap[0]] + + if len(metals) == 1: + fp = fp.reshape(-1) + elif len(metals) == 2: + fp = np.diff(uap).reshape(-1) + else: + raise NotImplementedError("""This operation is restricted to single + and binary metal system only.""") + return fp def derived_fp( atoms=None, diff --git a/catkit/pawprint/tests/.test_generator.py.swp b/catkit/pawprint/tests/.test_generator.py.swp new file mode 100644 index 00000000..28a755aa Binary files /dev/null and b/catkit/pawprint/tests/.test_generator.py.swp differ diff --git a/catkit/pawprint/tests/test_generator.py b/catkit/pawprint/tests/test_generator.py index 94546164..c773dfca 100644 --- a/catkit/pawprint/tests/test_generator.py +++ b/catkit/pawprint/tests/test_generator.py @@ -27,15 +27,17 @@ def test_nonlocal_fingerprinting(self): operations = [ 'periodic_convolution', - ['periodic_convolution', {'d': 1}] + ['periodic_convolution', {'d': 1}], + 'bimetal_fp' ] fp = Fingerprinter(images) fingerprints = fp.get_fp(parameters, operations) truth = np.array([ - [12432.0, 7.562800000000001, 320.44, 136896.0, 90.7488, 3844.8], - [2028.0, 24.53879999999999, 1200.0, 20280.0, 245.388, 12000.0]]) + [12432.0, 7.562800000000001, 320.44, 136896.0, 90.7488, 3844.8, + 3.20000e+01, 2.00000e-02, 2.00000e-01], [2028.0, 24.5387999999999, + 1200.0, 20280.0, 245.388, 12000.0, 0.00, 0.00, 0.00]]) np.testing.assert_allclose(fingerprints, truth)