
Update fedasdf #260

Merged: 3 commits, Jun 24, 2024
File renamed without changes.
File renamed without changes.
File renamed without changes.
78 changes: 40 additions & 38 deletions seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py
@@ -29,7 +29,7 @@
import sqlite3
import hashlib
from functools import partial
from seismic.ASDFdatabase.utils import MIN_DATE, MAX_DATE, cleanse_inventory
from seismic.ASDFdatabase.utils import MIN_DATE, MAX_DATE, cleanse_inventory, InventoryAggregator
from seismic.misc import split_list, setup_logger
import pickle as cPickle
import pandas as pd
@@ -334,22 +334,27 @@ def decode_tag(tag, type='raw_recording'):
self.comm.Barrier()

if(dbFound):
print(('Found database: %s'%(self.db_fn)))
print('Found database: %s'%(self.db_fn))
self.conn = sqlite3.connect(self.db_fn,
check_same_thread=self.single_threaded_access)
else:
if(self.rank==0):
ia = InventoryAggregator()

self.conn = sqlite3.connect(self.db_fn,
check_same_thread=self.single_threaded_access)
self.conn.execute('create table wdb(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), '
self.conn.execute('create table ds(ds_id smallint, path text)')
self.conn.execute('create table wtag(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), '
'cha varchar(6), st double, et double, tag text)')
self.conn.execute('create table netsta(ds_id smallint, net varchar(6), sta varchar(6), lon double, '
self.conn.execute('create table meta(ds_id smallint, net varchar(6), sta varchar(6), lon double, '
'lat double, elev_m double)')
self.conn.execute('create table masterinv(inv blob)')

metadatalist = []
masterinv = None
for ids, ds in enumerate(self.asdf_datasets):
self.conn.execute('insert into ds(ds_id, path) values(?, ?)',
[ids, self.asdf_file_names[ids]])

coords_dict = ds.get_all_coordinates()

# report any missing metadata
@@ -362,26 +367,22 @@ def decode_tag(tag, type='raw_recording'):
# end if

for k in coords_dict.keys():
if(not masterinv):
masterinv = ds.waveforms[k].StationXML
else:
try:
masterinv += cleanse_inventory(ds.waveforms[k].StationXML)
except Exception as e:
print(e)
# end try
# end if
# end for

for k in list(coords_dict.keys()):
# we keep coordinates from all ASDF files to be able to track
# potential discrepancies
lon = coords_dict[k]['longitude']
lat = coords_dict[k]['latitude']
elev_m = coords_dict[k]['elevation_in_m']
nc, sc = k.split('.')
metadatalist.append([ids, nc, sc, lon, lat, elev_m])

# aggregate inventories
inv = cleanse_inventory(ds.waveforms[k].StationXML)
ia.append(inv)
# end for
# end for
self.conn.executemany('insert into netsta(ds_id, net, sta, lon, lat, elev_m) values '

masterinv = ia.summarize()
self.conn.executemany('insert into meta(ds_id, net, sta, lon, lat, elev_m) values '
'(?, ?, ?, ?, ?, ?)', metadatalist)
self.conn.execute('insert into masterinv(inv) values(?)',
[cPickle.dumps(masterinv, cPickle.HIGHEST_PROTOCOL)])
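The hunk above replaces the single `wdb` table with four: `wtag` for waveform tags, `meta` for per-dataset station coordinates, `ds` for mapping dataset ids to file paths, and `masterinv` for one pickled master Inventory. A minimal, standalone sketch of the resulting schema, with an in-memory database standing in for `self.db_fn`:

import sqlite3

conn = sqlite3.connect(':memory:')  # stand-in for self.db_fn
# ds: maps a small integer dataset id to the ASDF file it came from
conn.execute('create table ds(ds_id smallint, path text)')
# wtag: one row per waveform tag, keyed by net/sta/loc/cha and epoch seconds
conn.execute('create table wtag(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), '
             'cha varchar(6), st double, et double, tag text)')
# meta: station coordinates, kept per dataset to expose discrepancies between files
conn.execute('create table meta(ds_id smallint, net varchar(6), sta varchar(6), lon double, '
             'lat double, elev_m double)')
# masterinv: a single pickled ObsPy Inventory summarizing all station metadata
conn.execute('create table masterinv(inv blob)')
conn.commit()

Keeping coordinates per `ds_id` in `meta` is what lets the loader track discrepancies between ASDF files that describe the same station.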
@@ -392,10 +393,11 @@ def decode_tag(tag, type='raw_recording'):
self.conn.commit()
self.conn.close()
# end if
self.comm.Barrier()

tagsCount = 0
for ids, ds in enumerate(self.asdf_datasets):
if(self.rank==0): print(('Indexing %s..' % (os.path.basename(self.asdf_file_names[ids]))))
if(self.rank==0): print('Indexing %s..' % (os.path.basename(self.asdf_file_names[ids])))

keys = list(ds.get_all_coordinates().keys())
keys = split_list(keys, self.nproc)
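Indexing work is split across MPI ranks: each rank scans only its share of the station keys. A stand-in sketch, assuming `seismic.misc.split_list` partitions a list into `nproc` roughly equal chunks (the real helper may differ in detail):

# a minimal stand-in for seismic.misc.split_list under the assumption above
def split_list(lst, npartitions):
    k, m = divmod(len(lst), npartitions)
    return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
            for i in range(npartitions)]

keys = ['AU.ARMA', 'AU.QIS', 'AU.EIDS', 'AU.YNG', 'AU.CMSA']
print(split_list(keys, 2))  # rank 0 takes the first chunk, rank 1 the second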
@@ -420,10 +422,10 @@ def decode_tag(tag, type='raw_recording'):
if(len(data)):
self.conn = sqlite3.connect(self.db_fn,
check_same_thread=self.single_threaded_access)
self.conn.executemany('insert into wdb(ds_id, net, sta, loc, cha, st, et, tag) values '
self.conn.executemany('insert into wtag(ds_id, net, sta, loc, cha, st, et, tag) values '
'(?, ?, ?, ?, ?, ?, ?, ?)', data)
print(('\tInserted %d entries on rank %d'%(len(data),
self.rank)))
print('\tInserted %d entries on rank %d'%(len(data),
self.rank))
tagsCount += len(data)
self.conn.commit()
self.conn.close()
@@ -438,8 +440,8 @@ def decode_tag(tag, type='raw_recording'):
print('Creating table indices..')
self.conn = sqlite3.connect(self.db_fn,
check_same_thread=self.single_threaded_access)
self.conn.execute('create index allindex on wdb(net, sta, loc, cha, st, et)')
self.conn.execute('create index netstaindex on netsta(ds_id, net, sta)')
self.conn.execute('create index allindex on wtag(ds_id, net, sta, loc, cha, st, et)')
self.conn.execute('create index metaindex on meta(ds_id, net, sta)')
self.conn.commit()
self.conn.close()
print('Done..')
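Both indices are composite, so a lookup that constrains a leading prefix of the indexed columns can be answered from the index alone. An illustrative check with SQLite's query planner, repeating the definitions from the hunk above:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('create table wtag(ds_id smallint, net varchar(6), sta varchar(6), loc varchar(6), '
             'cha varchar(6), st double, et double, tag text)')
conn.execute('create index allindex on wtag(ds_id, net, sta, loc, cha, st, et)')
# the planner reports whether allindex serves a typical lookup
for row in conn.execute("explain query plan select * from wtag "
                        "where ds_id=0 and net='AU' and sta='ARMA'"):
    print(row)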
@@ -450,7 +452,7 @@ def decode_tag(tag, type='raw_recording'):
# end if

# Load metadata
rows = self.conn.execute('select * from netsta').fetchall()
rows = self.conn.execute('select * from meta').fetchall()
for row in rows:
ds_id, net, sta, lon, lat, elev_m = row
self.asdf_station_coordinates[ds_id]['%s.%s' % (net.strip(), sta.strip())] = [lon, lat, elev_m]
@@ -462,7 +464,7 @@ def decode_tag(tag, type='raw_recording'):
# end func

def get_global_time_range(self, network, station=None, location=None, channel=None):
query = "select min(st), max(et) from wdb where net='%s' "%(network)
query = "select min(st), max(et) from wtag where net='%s' "%(network)

if (station is not None):
query += "and sta='%s' "%(station)
@@ -488,7 +490,7 @@ def get_stations(self, starttime, endtime, network=None, station=None, location=
starttime = UTCDateTime(starttime).timestamp
endtime = UTCDateTime(endtime).timestamp

query = 'select * from wdb where '
query = 'select * from wtag where '
if (network): query += " net='%s' "%(network)
if (station):
if(network): query += "and sta='%s' "%(station)
@@ -522,7 +524,7 @@ def get_waveform_count(self, network, station, location, channel, starttime, end
starttime = UTCDateTime(starttime).timestamp
endtime = UTCDateTime(endtime).timestamp

query = "select count(*) from wdb where net='%s' and sta='%s' and loc='%s' and cha='%s' " \
query = "select count(*) from wtag where net='%s' and sta='%s' and loc='%s' and cha='%s' " \
%(network, station, location, channel) + \
"and et>=%f and st<=%f" \
% (starttime, endtime)
@@ -538,7 +540,7 @@ def get_waveforms(self, network, station, location, channel, starttime,
starttime = UTCDateTime(starttime)
endtime = UTCDateTime(endtime)

query = "select * from wdb where net='%s' and sta='%s' and loc='%s' and cha='%s' " \
query = "select * from wtag where net='%s' and sta='%s' and loc='%s' and cha='%s' " \
%(network, station, location, channel) + \
"and et>=%f and st<=%f" \
% (starttime.timestamp, endtime.timestamp)
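The `et>=starttime and st<=endtime` clause used by `get_waveform_count` and `get_waveforms` is the standard interval-overlap test: a stored trace `[st, et]` qualifies iff it intersects the query window. A minimal sketch:

def overlaps(st, et, starttime, endtime):
    # [st, et] intersects [starttime, endtime] iff it ends after the
    # window opens and starts before the window closes
    return et >= starttime and st <= endtime

assert overlaps(10.0, 20.0, 15.0, 25.0)      # partial overlap
assert overlaps(10.0, 20.0, 12.0, 18.0)      # fully contained
assert not overlaps(10.0, 20.0, 20.5, 30.0)  # disjoint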
@@ -639,14 +641,14 @@ def stations_iterator(self, network_list=[], station_list=[]):
workload.append(defaultdict(partial(defaultdict, list)))
# end for

nets = self.conn.execute('select distinct net from wdb').fetchall()
nets = self.conn.execute('select distinct net from wtag').fetchall()
if(len(network_list)): # filter networks
nets = [net for net in nets if net[0] in network_list]
# end if

for net in nets:
net = net[0]
stas = self.conn.execute("select distinct sta from wdb where net='%s'"%(net)).fetchall()
stas = self.conn.execute("select distinct sta from wtag where net='%s'"%(net)).fetchall()

if (len(station_list)): # filter stations
stas = [sta for sta in stas if sta[0] in station_list]
@@ -656,7 +658,7 @@ def stations_iterator(self, network_list=[], station_list=[]):
sta = sta[0]

# trace-count, min(st), max(et)
attribs = self.conn.execute("select count(st), min(st), max(et) from wdb where net='%s' and sta='%s'"
attribs = self.conn.execute("select count(st), min(st), max(et) from wtag where net='%s' and sta='%s'"
%(net, sta)).fetchall()

if(len(attribs)==0): continue
@@ -710,7 +712,7 @@ def find_gaps(self, network=None, station=None, location=None,
min_gap_length=86400):

clause_added = 0
query = 'select net, sta, loc, cha, st, et from wdb '
query = 'select net, sta, loc, cha, st, et from wtag '
if (network or station or location or channel or (start_date_ts and end_date_ts)): query += " where "

if (network):
@@ -818,11 +820,11 @@ def find_gaps(self, network=None, station=None, location=None,
def get_coverage(self, network=None):
query = """
select w.net, w.sta, w.loc, w.cha, n.lon, n.lat, min(w.st), max(w.et)
from wdb as w, netsta as n where w.net=n.net and w.sta=n.sta and
w.net in (select distinct net from netsta) and
w.sta in (select distinct sta from netsta) and
w.loc in (select distinct loc from wdb) and
w.cha in (select distinct cha from wdb)
from wtag as w, meta as n where w.net=n.net and w.sta=n.sta and
w.net in (select distinct net from meta) and
w.sta in (select distinct sta from meta) and
w.loc in (select distinct loc from wtag) and
w.cha in (select distinct cha from wtag)
"""
if(network): query += ' and w.net="{}"'.format(network)
query += " group by w.net, w.sta, w.loc, w.cha; "
10 changes: 5 additions & 5 deletions seismic/ASDFdatabase/export_station_locations.py
@@ -89,8 +89,8 @@ def write_csv(rows, ofn):

ds = FederatedASDFDataSet(asdf_source)

query = 'select ns.net, ns.sta, ns.lon, ns.lat from netsta as ns, wdb as wdb ' \
'where ns.net=wdb.net and ns.sta=wdb.sta '
query = 'select ns.net, ns.sta, ns.lon, ns.lat from meta as ns, wtag as wt ' \
'where ns.net=wt.net and ns.sta=wt.sta '

if (network):
query += ' and ns.net="{}" '.format(network)
@@ -101,15 +101,15 @@ def write_csv(rows, ofn):
# end if

if (location):
query += ' and wdb.loc="{}" '.format(location)
query += ' and wt.loc="{}" '.format(location)
# end if

if (channel):
query += ' and wdb.cha="{}" '.format(channel)
query += ' and wt.cha="{}" '.format(channel)
# end if

if (start_date_ts and end_date_ts):
query += ' and wdb.st>={} and wdb.et<={}'.format(start_date_ts, end_date_ts)
query += ' and wt.st>={} and wt.et<={}'.format(start_date_ts, end_date_ts)
# end if

query += ' group by ns.net, ns.sta'
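With every optional clause supplied, the composed query reads roughly as below (values are illustrative, not from the PR); the final `group by` de-duplicates to one row per station:

# illustrative composition for network='AU', channel='BHZ' and a time window
query = ('select ns.net, ns.sta, ns.lon, ns.lat from meta as ns, wtag as wt '
         'where ns.net=wt.net and ns.sta=wt.sta '
         ' and ns.net="AU" '
         ' and wt.cha="BHZ" '
         ' and wt.st>=0 and wt.et<=4102444800'
         ' group by ns.net, ns.sta')
print(query)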
75 changes: 74 additions & 1 deletion seismic/ASDFdatabase/utils.py
@@ -8,8 +8,9 @@
import os
from tqdm import tqdm
from ordered_set import OrderedSet as set
from seismic.misc import split_list
from obspy import Inventory
import obspy
import copy

MAX_DATE = UTCDateTime(4102444800.0) #2100-01-01
MIN_DATE = UTCDateTime(-2208988800.0) #1900-01-01
@@ -30,6 +31,78 @@ def cleanse_inventory(iinv: Inventory) -> Inventory:
return oinv
# end func

class InventoryAggregator:
def __init__(self):
tree = lambda: defaultdict(tree)
self.net_dict = tree()
self.sta_dict = tree()
self.cha_dict = tree()
# end func

def append(self, inv: Inventory):
for net in inv.networks:
nc = net.code

if (type(self.net_dict[nc]) == defaultdict):
onet = copy.deepcopy(net)
onet.stations = []
self.net_dict[nc] = onet
# end if

for sta in net.stations:
sc = sta.code

if (type(self.sta_dict[nc][sc]) == defaultdict):
osta = copy.deepcopy(sta)
osta.channels = []
self.sta_dict[nc][sc] = osta
# end if

for cha in sta.channels:
cc = cha.code
lc = cha.location_code

# set responses to None
try:
cha.response = None
except Exception:
pass

if (type(self.cha_dict[nc][sc][lc][cc]) == defaultdict):
self.cha_dict[nc][sc][lc][cc] = cha
# end if
# end for
# end for
# end for
# end func

def summarize(self):
oinv = Inventory(networks=[],
source=obspy.core.util.version.read_release_version())

for nc in self.net_dict.keys():
net = self.net_dict[nc]

for sc in self.sta_dict[nc].keys():
sta = self.sta_dict[nc][sc]

for lc in self.cha_dict[nc][sc].keys():
for cc in self.cha_dict[nc][sc][lc].keys():
cha = self.cha_dict[nc][sc][lc][cc]

sta.channels.append(cha)
# end for
# end for
net.stations.append(sta)
# end for

oinv.networks.append(net)
# end for

return oinv
# end func
# end class
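A hedged usage sketch for the new class: append one cleansed Inventory at a time, then collect the de-duplicated master Inventory at the end. The StationXML file names are hypothetical:

from obspy import read_inventory

ia = InventoryAggregator()
for fn in ['AU.ARMA.xml', 'AU.QIS.xml']:  # hypothetical StationXML files
    ia.append(cleanse_inventory(read_inventory(fn)))

master = ia.summarize()  # one entry per unique net/sta/loc/cha code
print(master)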

class MseedIndex:
class StreamCache:
def __init__(self):
1 change: 1 addition & 0 deletions seismic/extract_event_traces.py
@@ -167,6 +167,7 @@ def asdf_get_waveforms(asdf_dataset, network, station, location, channel, startt
matching_stations = asdf_dataset.get_stations(starttime, endtime, network=network, station=station,
location=location)
if matching_stations:
channel = channel.replace('?', '.') # translate the SEED '?' wildcard (any one character) to its regex equivalent
ch_matcher = re.compile(channel)
for net, sta, loc, cha, _, _, _ in matching_stations:
if ch_matcher.match(cha):
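This one-line fix matters because `?` is a single-character wildcard in SEED channel selectors but an optional-quantifier in a regular expression: before the translation, a pattern like 'BH?' matched any code beginning with 'B'. A minimal check:

import re

# without the fix: '?' makes the preceding 'H' optional,
# so 'BH?' also matches codes that merely start with 'B'
assert re.compile('BH?').match('BZZ') is not None
# with the fix: '.' matches exactly one arbitrary character
assert re.compile('BH.').match('BHZ') is not None
assert re.compile('BH.').match('BZZ') is None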
@@ -57,7 +57,7 @@ def test_db_integrity():

# get number of waveforms from the db directly
conn = sqlite3.connect(fds.fds.db_fn)
query = 'select count(*) from wdb;'
query = 'select count(*) from wtag;'
db_waveform_count = conn.execute(query).fetchall()[0][0]

# fetch waveform counts for each unique combination of net, sta, loc, cha