From fc5ad3de6a211aa1dfc7af436d0111fc3e016a99 Mon Sep 17 00:00:00 2001 From: Rakib Hassan Date: Fri, 21 Jun 2024 15:03:57 +1000 Subject: [PATCH 1/3] Added a new class to aggregate inventories without duplication --- .../ASDFdatabase/_FederatedASDFDataSetImpl.py | 28 ++++--- seismic/ASDFdatabase/utils.py | 75 ++++++++++++++++++- 2 files changed, 87 insertions(+), 16 deletions(-) diff --git a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py index 229b22fb..bb02441e 100644 --- a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py +++ b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py @@ -29,7 +29,7 @@ import sqlite3 import hashlib from functools import partial -from seismic.ASDFdatabase.utils import MIN_DATE, MAX_DATE, cleanse_inventory +from seismic.ASDFdatabase.utils import MIN_DATE, MAX_DATE, cleanse_inventory, InventoryAggregator from seismic.misc import split_list, setup_logger import pickle as cPickle import pandas as pd @@ -339,6 +339,8 @@ def decode_tag(tag, type='raw_recording'): check_same_thread=self.single_threaded_access) else: if(self.rank==0): + ia = InventoryAggregator() + self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) self.conn.execute('create table wdb(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), ' @@ -348,7 +350,6 @@ def decode_tag(tag, type='raw_recording'): self.conn.execute('create table masterinv(inv blob)') metadatalist = [] - masterinv = None for ids, ds in enumerate(self.asdf_datasets): coords_dict = ds.get_all_coordinates() @@ -362,25 +363,21 @@ def decode_tag(tag, type='raw_recording'): # end if for k in coords_dict.keys(): - if(not masterinv): - masterinv = ds.waveforms[k].StationXML - else: - try: - masterinv += cleanse_inventory(ds.waveforms[k].StationXML) - except Exception as e: - print(e) - # end try - # end if - # end for - - for k in list(coords_dict.keys()): + # we keep coordinates from all ASDF files to be able to track + # potential discrepancies lon = coords_dict[k]['longitude'] lat = coords_dict[k]['latitude'] elev_m = coords_dict[k]['elevation_in_m'] nc, sc = k.split('.') metadatalist.append([ids, nc, sc, lon, lat, elev_m]) + + # aggregate inventories + inv = cleanse_inventory(ds.waveforms[k].StationXML) + ia.append(inv) # end for # end for + + masterinv = ia.summarize() self.conn.executemany('insert into netsta(ds_id, net, sta, lon, lat, elev_m) values ' '(?, ?, ?, ?, ?, ?)', metadatalist) self.conn.execute('insert into masterinv(inv) values(?)', @@ -392,6 +389,7 @@ def decode_tag(tag, type='raw_recording'): self.conn.commit() self.conn.close() # end if + self.comm.Barrier() tagsCount = 0 for ids, ds in enumerate(self.asdf_datasets): @@ -438,7 +436,7 @@ def decode_tag(tag, type='raw_recording'): print('Creating table indices..') self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) - self.conn.execute('create index allindex on wdb(net, sta, loc, cha, st, et)') + self.conn.execute('create index allindex on wdb(ds_id, net, sta, loc, cha, st, et)') self.conn.execute('create index netstaindex on netsta(ds_id, net, sta)') self.conn.commit() self.conn.close() diff --git a/seismic/ASDFdatabase/utils.py b/seismic/ASDFdatabase/utils.py index 16a74cd1..1c76e08f 100644 --- a/seismic/ASDFdatabase/utils.py +++ b/seismic/ASDFdatabase/utils.py @@ -8,8 +8,9 @@ import os from tqdm import tqdm from ordered_set import OrderedSet as set -from seismic.misc import split_list from obspy import Inventory +import obspy +import copy MAX_DATE = UTCDateTime(4102444800.0) #2100-01-01 MIN_DATE = UTCDateTime(-2208988800.0) #1900-01-01 @@ -30,6 +31,78 @@ def cleanse_inventory(iinv: Inventory) -> Inventory: return oinv # end func +class InventoryAggregator: + def __init__(self): + tree = lambda: defaultdict(tree) + self.net_dict = tree() + self.sta_dict = tree() + self.cha_dict = tree() + # end func + + def append(self, inv: Inventory): + for net in inv.networks: + nc = net.code + + if (type(self.net_dict[nc]) == defaultdict): + onet = copy.deepcopy(net) + onet.stations = [] + self.net_dict[nc] = onet + # end if + + for sta in net.stations: + sc = sta.code + + if (type(self.sta_dict[nc][sc]) == defaultdict): + osta = copy.deepcopy(sta) + osta.channels = [] + self.sta_dict[nc][sc] = osta + # end if + + for cha in sta.channels: + cc = cha.code + lc = cha.location_code + + # set responses to None + try: + cc.response = None + except: + pass + + if (type(self.cha_dict[nc][sc][lc][cc]) == defaultdict): + self.cha_dict[nc][sc][lc][cc] = cha + # end if + # end for + # end for + # end for + # end func + + def summarize(self): + oinv = Inventory(networks=[], + source=obspy.core.util.version.read_release_version()) + + for nc in self.net_dict.keys(): + net = self.net_dict[nc] + + for sc in self.sta_dict[nc].keys(): + sta = self.sta_dict[nc][sc] + + for lc in self.cha_dict[nc][sc].keys(): + for cc in self.cha_dict[nc][sc][lc].keys(): + cha = self.cha_dict[nc][sc][lc][cc] + + sta.channels.append(cha) + # end for + # end for + net.stations.append(sta) + # end for + + oinv.networks.append(net) + # end for + + return oinv + # end func +# end class + class MseedIndex: class StreamCache: def __init__(self): From de8d11a049fed6f9a06747e2072120ea6e81a2aa Mon Sep 17 00:00:00 2001 From: Rakib Hassan Date: Fri, 21 Jun 2024 16:05:23 +1000 Subject: [PATCH 2/3] Disabling greedy channel-code matching --- {tutorial => legacy/tutorial}/ASDF_Federated_ASDF.ipynb | 0 {tutorial => legacy/tutorial}/__init__.py | 0 .../tutorial}/get_waveforms_from_fedasdf.py | 0 {tutorial => legacy/tutorial}/inspect_fasdf.py | 0 {tutorial => legacy/tutorial}/pbs_inspect_fasdf.sh | 0 seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py | 8 ++++---- seismic/extract_event_traces.py | 1 + 7 files changed, 5 insertions(+), 4 deletions(-) rename {tutorial => legacy/tutorial}/ASDF_Federated_ASDF.ipynb (100%) rename {tutorial => legacy/tutorial}/__init__.py (100%) rename {tutorial => legacy/tutorial}/get_waveforms_from_fedasdf.py (100%) rename {tutorial => legacy/tutorial}/inspect_fasdf.py (100%) rename {tutorial => legacy/tutorial}/pbs_inspect_fasdf.sh (100%) diff --git a/tutorial/ASDF_Federated_ASDF.ipynb b/legacy/tutorial/ASDF_Federated_ASDF.ipynb similarity index 100% rename from tutorial/ASDF_Federated_ASDF.ipynb rename to legacy/tutorial/ASDF_Federated_ASDF.ipynb diff --git a/tutorial/__init__.py b/legacy/tutorial/__init__.py similarity index 100% rename from tutorial/__init__.py rename to legacy/tutorial/__init__.py diff --git a/tutorial/get_waveforms_from_fedasdf.py b/legacy/tutorial/get_waveforms_from_fedasdf.py similarity index 100% rename from tutorial/get_waveforms_from_fedasdf.py rename to legacy/tutorial/get_waveforms_from_fedasdf.py diff --git a/tutorial/inspect_fasdf.py b/legacy/tutorial/inspect_fasdf.py similarity index 100% rename from tutorial/inspect_fasdf.py rename to legacy/tutorial/inspect_fasdf.py diff --git a/tutorial/pbs_inspect_fasdf.sh b/legacy/tutorial/pbs_inspect_fasdf.sh similarity index 100% rename from tutorial/pbs_inspect_fasdf.sh rename to legacy/tutorial/pbs_inspect_fasdf.sh diff --git a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py index bb02441e..fa6394ad 100644 --- a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py +++ b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py @@ -334,7 +334,7 @@ def decode_tag(tag, type='raw_recording'): self.comm.Barrier() if(dbFound): - print(('Found database: %s'%(self.db_fn))) + print('Found database: %s'%(self.db_fn)) self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) else: @@ -393,7 +393,7 @@ def decode_tag(tag, type='raw_recording'): tagsCount = 0 for ids, ds in enumerate(self.asdf_datasets): - if(self.rank==0): print(('Indexing %s..' % (os.path.basename(self.asdf_file_names[ids])))) + if(self.rank==0): print('Indexing %s..' % (os.path.basename(self.asdf_file_names[ids]))) keys = list(ds.get_all_coordinates().keys()) keys = split_list(keys, self.nproc) @@ -420,8 +420,8 @@ def decode_tag(tag, type='raw_recording'): check_same_thread=self.single_threaded_access) self.conn.executemany('insert into wdb(ds_id, net, sta, loc, cha, st, et, tag) values ' '(?, ?, ?, ?, ?, ?, ?, ?)', data) - print(('\tInserted %d entries on rank %d'%(len(data), - self.rank))) + print('\tInserted %d entries on rank %d'%(len(data), + self.rank)) tagsCount += len(data) self.conn.commit() self.conn.close() diff --git a/seismic/extract_event_traces.py b/seismic/extract_event_traces.py index 9e596fc6..b29254eb 100644 --- a/seismic/extract_event_traces.py +++ b/seismic/extract_event_traces.py @@ -167,6 +167,7 @@ def asdf_get_waveforms(asdf_dataset, network, station, location, channel, startt matching_stations = asdf_dataset.get_stations(starttime, endtime, network=network, station=station, location=location) if matching_stations: + channel = channel.replace('?', '.') # replace greedy matching by single-character matching ch_matcher = re.compile(channel) for net, sta, loc, cha, _, _, _ in matching_stations: if ch_matcher.match(cha): From b0f114bfc20ad8e4577db62da37befae0edff1aa Mon Sep 17 00:00:00 2001 From: Rakib Hassan Date: Fri, 21 Jun 2024 20:38:42 +1000 Subject: [PATCH 3/3] Renamed tables in waveform database --- .../ASDFdatabase/_FederatedASDFDataSetImpl.py | 44 ++++++++++--------- .../ASDFdatabase/export_station_locations.py | 10 ++--- .../ASDFdatabase/test_federatedasdfdataset.py | 2 +- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py index fa6394ad..143abc6a 100644 --- a/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py +++ b/seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py @@ -343,14 +343,18 @@ def decode_tag(tag, type='raw_recording'): self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) - self.conn.execute('create table wdb(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), ' + self.conn.execute('create table ds(ds_id smallint, path text)') + self.conn.execute('create table wtag(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), ' 'cha varchar(6), st double, et double, tag text)') - self.conn.execute('create table netsta(ds_id smallint, net varchar(6), sta varchar(6), lon double, ' + self.conn.execute('create table meta(ds_id smallint, net varchar(6), sta varchar(6), lon double, ' 'lat double, elev_m double)') self.conn.execute('create table masterinv(inv blob)') metadatalist = [] for ids, ds in enumerate(self.asdf_datasets): + self.conn.execute('insert into ds(ds_id, path) values(?, ?)', + [ids, self.asdf_file_names[ids]]) + coords_dict = ds.get_all_coordinates() # report any missing metadata @@ -378,7 +382,7 @@ def decode_tag(tag, type='raw_recording'): # end for masterinv = ia.summarize() - self.conn.executemany('insert into netsta(ds_id, net, sta, lon, lat, elev_m) values ' + self.conn.executemany('insert into meta(ds_id, net, sta, lon, lat, elev_m) values ' '(?, ?, ?, ?, ?, ?)', metadatalist) self.conn.execute('insert into masterinv(inv) values(?)', [cPickle.dumps(masterinv, cPickle.HIGHEST_PROTOCOL)]) @@ -418,7 +422,7 @@ def decode_tag(tag, type='raw_recording'): if(len(data)): self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) - self.conn.executemany('insert into wdb(ds_id, net, sta, loc, cha, st, et, tag) values ' + self.conn.executemany('insert into wtag(ds_id, net, sta, loc, cha, st, et, tag) values ' '(?, ?, ?, ?, ?, ?, ?, ?)', data) print('\tInserted %d entries on rank %d'%(len(data), self.rank)) @@ -436,8 +440,8 @@ def decode_tag(tag, type='raw_recording'): print('Creating table indices..') self.conn = sqlite3.connect(self.db_fn, check_same_thread=self.single_threaded_access) - self.conn.execute('create index allindex on wdb(ds_id, net, sta, loc, cha, st, et)') - self.conn.execute('create index netstaindex on netsta(ds_id, net, sta)') + self.conn.execute('create index allindex on wtag(ds_id, net, sta, loc, cha, st, et)') + self.conn.execute('create index metaindex on meta(ds_id, net, sta)') self.conn.commit() self.conn.close() print('Done..') @@ -448,7 +452,7 @@ def decode_tag(tag, type='raw_recording'): # end if # Load metadata - rows = self.conn.execute('select * from netsta').fetchall() + rows = self.conn.execute('select * from meta').fetchall() for row in rows: ds_id, net, sta, lon, lat, elev_m = row self.asdf_station_coordinates[ds_id]['%s.%s' % (net.strip(), sta.strip())] = [lon, lat, elev_m] @@ -460,7 +464,7 @@ def decode_tag(tag, type='raw_recording'): # end func def get_global_time_range(self, network, station=None, location=None, channel=None): - query = "select min(st), max(et) from wdb where net='%s' "%(network) + query = "select min(st), max(et) from wtag where net='%s' "%(network) if (station is not None): query += "and sta='%s' "%(station) @@ -486,7 +490,7 @@ def get_stations(self, starttime, endtime, network=None, station=None, location= starttime = UTCDateTime(starttime).timestamp endtime = UTCDateTime(endtime).timestamp - query = 'select * from wdb where ' + query = 'select * from wtag where ' if (network): query += " net='%s' "%(network) if (station): if(network): query += "and sta='%s' "%(station) @@ -520,7 +524,7 @@ def get_waveform_count(self, network, station, location, channel, starttime, end starttime = UTCDateTime(starttime).timestamp endtime = UTCDateTime(endtime).timestamp - query = "select count(*) from wdb where net='%s' and sta='%s' and loc='%s' and cha='%s' " \ + query = "select count(*) from wtag where net='%s' and sta='%s' and loc='%s' and cha='%s' " \ %(network, station, location, channel) + \ "and et>=%f and st<=%f" \ % (starttime, endtime) @@ -536,7 +540,7 @@ def get_waveforms(self, network, station, location, channel, starttime, starttime = UTCDateTime(starttime) endtime = UTCDateTime(endtime) - query = "select * from wdb where net='%s' and sta='%s' and loc='%s' and cha='%s' " \ + query = "select * from wtag where net='%s' and sta='%s' and loc='%s' and cha='%s' " \ %(network, station, location, channel) + \ "and et>=%f and st<=%f" \ % (starttime.timestamp, endtime.timestamp) @@ -637,14 +641,14 @@ def stations_iterator(self, network_list=[], station_list=[]): workload.append(defaultdict(partial(defaultdict, list))) # end for - nets = self.conn.execute('select distinct net from wdb').fetchall() + nets = self.conn.execute('select distinct net from wtag').fetchall() if(len(network_list)): # filter networks nets = [net for net in nets if net[0] in network_list] # end if for net in nets: net = net[0] - stas = self.conn.execute("select distinct sta from wdb where net='%s'"%(net)).fetchall() + stas = self.conn.execute("select distinct sta from wtag where net='%s'"%(net)).fetchall() if (len(station_list)): # filter stations stas = [sta for sta in stas if sta[0] in station_list] @@ -654,7 +658,7 @@ def stations_iterator(self, network_list=[], station_list=[]): sta = sta[0] # trace-count, min(st), max(et) - attribs = self.conn.execute("select count(st), min(st), max(et) from wdb where net='%s' and sta='%s'" + attribs = self.conn.execute("select count(st), min(st), max(et) from wtag where net='%s' and sta='%s'" %(net, sta)).fetchall() if(len(attribs)==0): continue @@ -708,7 +712,7 @@ def find_gaps(self, network=None, station=None, location=None, min_gap_length=86400): clause_added = 0 - query = 'select net, sta, loc, cha, st, et from wdb ' + query = 'select net, sta, loc, cha, st, et from wtag ' if (network or station or location or channel or (start_date_ts and end_date_ts)): query += " where " if (network): @@ -816,11 +820,11 @@ def find_gaps(self, network=None, station=None, location=None, def get_coverage(self, network=None): query = """ select w.net, w.sta, w.loc, w.cha, n.lon, n.lat, min(w.st), max(w.et) - from wdb as w, netsta as n where w.net=n.net and w.sta=n.sta and - w.net in (select distinct net from netsta) and - w.sta in (select distinct sta from netsta) and - w.loc in (select distinct loc from wdb) and - w.cha in (select distinct cha from wdb) + from wtag as w, meta as n where w.net=n.net and w.sta=n.sta and + w.net in (select distinct net from meta) and + w.sta in (select distinct sta from meta) and + w.loc in (select distinct loc from wtag) and + w.cha in (select distinct cha from wtag) """ if(network): query += ' and w.net="{}"'.format(network) query += " group by w.net, w.sta, w.loc, w.cha; " diff --git a/seismic/ASDFdatabase/export_station_locations.py b/seismic/ASDFdatabase/export_station_locations.py index 48364d91..776863e8 100644 --- a/seismic/ASDFdatabase/export_station_locations.py +++ b/seismic/ASDFdatabase/export_station_locations.py @@ -89,8 +89,8 @@ def write_csv(rows, ofn): ds = FederatedASDFDataSet(asdf_source) - query = 'select ns.net, ns.sta, ns.lon, ns.lat from netsta as ns, wdb as wdb ' \ - 'where ns.net=wdb.net and ns.sta=wdb.sta ' + query = 'select ns.net, ns.sta, ns.lon, ns.lat from meta as ns, wtag as wt ' \ + 'where ns.net=wt.net and ns.sta=wt.sta ' if (network): query += ' and ns.net="{}" '.format(network) @@ -101,15 +101,15 @@ def write_csv(rows, ofn): # end if if (location): - query += ' and wdb.loc="{}" '.format(location) + query += ' and wt.loc="{}" '.format(location) # end if if (channel): - query += ' and wdb.cha="{}" '.format(channel) + query += ' and wt.cha="{}" '.format(channel) # end if if (start_date_ts and end_date_ts): - query += ' and wdb.st>={} and wdb.et<={}'.format(start_date_ts, end_date_ts) + query += ' and wt.st>={} and wt.et<={}'.format(start_date_ts, end_date_ts) # end if query += ' group by ns.net, ns.sta' diff --git a/tests/test_seismic/ASDFdatabase/test_federatedasdfdataset.py b/tests/test_seismic/ASDFdatabase/test_federatedasdfdataset.py index f5b7c7e1..8dde6cf8 100644 --- a/tests/test_seismic/ASDFdatabase/test_federatedasdfdataset.py +++ b/tests/test_seismic/ASDFdatabase/test_federatedasdfdataset.py @@ -57,7 +57,7 @@ def test_db_integrity(): # get number of waveforms from the db directly conn = sqlite3.connect(fds.fds.db_fn) - query = 'select count(*) from wdb;' + query = 'select count(*) from wtag;' db_waveform_count = conn.execute(query).fetchall()[0][0] # fetch waveform counts for each unique combination of net, sta, loc, cha