-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
1973 lines (1564 loc) · 81.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#Some useful functions that I do not want to put elsewhere
import numpy as np
import xarray as xr
from bisect import bisect_left
import bisect
import dataloader
import psm_pseudoproxy
import evaluation
import tqdm
import warnings
from numba import njit,prange
import os
import pdb
def _broadcast_to_obsdata(cfg, key, only_none=False):
    """
    Repeat the single entry of the list cfg[key] once per proxy database.

    The list is only expanded when its length differs from len(cfg['obsdata'])
    and it has exactly one element. With only_none=True that element must
    additionally be None (used for 'psm'). Other length mismatches are left
    untouched on purpose (the corresponding length checks are disabled).
    """
    values = cfg[key]
    n_db = len(cfg['obsdata'])
    if len(values) == n_db or len(values) != 1:
        return
    if only_none and values[0] is not None:
        return
    cfg[key] = [values[0] for _ in range(n_db)]


def config_check(cfg):
    """
    Normalise and sanity-check a configuration dictionary (modified in place).

    - Wraps 'metrics', 'psm', 'obsdata', 'proxy_error' and 'time_scales' into
      lists when they are scalars. 'metrics' is converted with str(), so a
      None becomes the string 'None'.
    - Expands a single None 'psm' entry, and single 'proxy_error' /
      'time_scales' entries, to one entry per proxy database in 'obsdata'.
    - When pseudoproxies are used (cfg['ppe']['use'] == True), wraps
      cfg['ppe']['metrics_ppe'] into a list.
    - Converts cfg['timescales'] into a sorted numpy integer array and checks
      that each timescale divides the largest one (the block length).

    Does not check everything!

    Parameters
    ----------
    cfg : dict
        Configuration dictionary.

    Returns
    -------
    dict
        The same, normalised, configuration dictionary.

    Raises
    ------
    ValueError
        If a timescale is not a divisor of the largest timescale.
    """
    if not isinstance(cfg['metrics'], list):
        # string conversion guarantees that None is skipped
        cfg['metrics'] = [str(cfg['metrics'])]
    for key in ('psm', 'obsdata', 'proxy_error', 'time_scales'):
        if not isinstance(cfg[key], list):
            cfg[key] = [cfg[key]]

    # one psm / proxy_error / time_scales entry per proxy database
    _broadcast_to_obsdata(cfg, 'psm', only_none=True)
    _broadcast_to_obsdata(cfg, 'proxy_error')
    _broadcast_to_obsdata(cfg, 'time_scales')

    # additional adjustments when pseudoproxies are used
    if cfg['ppe']['use'] == True:
        if not isinstance(cfg['ppe']['metrics_ppe'], list):
            cfg['ppe']['metrics_ppe'] = [cfg['ppe']['metrics_ppe']]

    # multi-timescale list: convert to sorted integers
    ts_list = cfg['timescales']
    if not isinstance(ts_list, (list, np.ndarray)):
        ts_list = [ts_list]
    timescales = np.sort([int(t) for t in ts_list])
    cfg['timescales'] = timescales

    # every resolution must divide the block length (largest timescale);
    # raised explicitly because asserts are stripped under `python -O`
    block_length = timescales[-1]
    for t in timescales[:-1]:
        if block_length % t != 0:
            raise ValueError(
                'Time resolution {t} is not a divisor of {bl} '.format(t=t, bl=block_length)
            )
    return cfg
def prior_preparation(cfg):
    """
    Load the prior variables and compute their annual means.

    Every (variable, path) pair of cfg.vp whose path is not None is loaded with
    dataloader.load_prior and averaged per year with dataloader.annual_mean
    (using cfg.avg and cfg.check_nan).

    Returns
    -------
    prior : xarray.Dataset
        Annual-mean variables merged into one Dataset (a Dataset also when
        only one variable is loaded).
    attributes : list of dict
        The attrs of each variable of `prior`, in the Dataset's variable order.
    prior_raw : xarray.Dataset
        The raw (un-averaged) variables, kept in case the PSM needs them.
    """
    annual_fields = []
    raw_fields = []
    for var_name, path in cfg.vp.items():
        # variables whose path is None are skipped
        if path is None:
            continue
        print('Load', var_name, ' from ', path)
        print('computing yearly average')
        raw = dataloader.load_prior(path, var_name)
        annual = dataloader.annual_mean(raw, avg=cfg.avg, check_nan=cfg.check_nan)
        annual_fields.append(annual)
        raw_fields.append(raw)

    prior = xr.merge(annual_fields)
    prior_raw = xr.merge(raw_fields)
    attributes = [prior[name].attrs for name in prior.keys()]
    return prior, attributes, prior_raw
def _select_region(proxy_db, bounds):
    """
    Keep only the sites of `proxy_db` inside bounds=[[latS, latN], [lonW, lonE]].

    If lonW >= lonE the longitude band is taken to cross the zero meridian,
    i.e. sites with lon >= lonW or lon <= lonE are kept.
    """
    lats, lons = bounds[0], bounds[1]
    lat = proxy_db.lat
    lon = proxy_db.lon
    inside = (lat >= lats[0]) & (lat <= lats[1])
    if lons[0] < lons[1]:
        # normal selection of longitudes
        inside = inside & (lon >= lons[0]) & (lon <= lons[1])
    else:
        # region crossing the zero meridian
        inside = inside & ((lon >= lons[0]) | (lon <= lons[1]))
    return proxy_db.where(inside, drop=True)


def proxy_load(c):
    """
    Load the proxy databases of c.obsdata on a common yearly time axis.

    The time axis runs from c.proxy_time[0] to c.proxy_time[1] in yearly steps
    starting at the beginning of each year ('YS', 365_day calendar). It is the
    axis later used for the proxy time resolution and the resampling.

    For each database (index ip):
      - if c.only_regional_proxies is True and c.regional_bounds is set, keep
        only sites inside [[latS, latN], [lonW, lonE]] (lat AND lon),
      - if c.how_many[ip] is given, keep that many randomly chosen sites
        (meant for PPEs),
      - slice and reindex to the time axis,
      - read values from c.obs_var[ip][0] and errors from c.obs_var[ip][1]
        (or from c.linear['error'] when c.psm[ip] == 'linear'),
      - drop sites without any value on the time axis (only for real proxies,
        i.e. when c.ppe['use'] is not True),
      - if c.proxy_error[ip] is not None, replace the error at every available
        time by that fixed value,
      - rename the sites to '<ip>.<site>' strings to tell databases apart.

    Returns
    -------
    pp_y_all, pp_r_all : list of xarray.DataArray
        Proxy values and proxy errors, one DataArray per database.
    """
    print('>>>>>LOADING PROXY DATA')
    time = xr.cftime_range(start=c.proxy_time[0], end=c.proxy_time[1], freq='YS', calendar='365_day')
    pp_y_all = []
    pp_r_all = []
    for ip, p in enumerate(c.obsdata):
        proxy_db = xr.open_dataset(p, use_cftime=True)

        # eventually limit proxies geographically
        if c.only_regional_proxies == True:
            bounds = c.regional_bounds  # [[latS,latN],[lonW,lonE]]
            if bounds not in (False, None):
                proxy_db = _select_region(proxy_db, bounds)

        # option to select only a fraction of the proxies (meant for PPEs);
        # a missing attribute or missing entry for this database keeps all sites
        how_many = getattr(c, 'how_many', None)
        if how_many is not None:
            try:
                prox_mems = min(how_many[ip], len(proxy_db.site))  # absolute number
            except (IndexError, KeyError, TypeError):
                prox_mems = None
            if prox_mems is not None:
                prox_idx = dataloader.random_indices(prox_mems, len(proxy_db.site), reps=1, seed=c.seed)
                proxy_db = proxy_db.isel(site=prox_idx[0])

        proxy_db = proxy_db.sel(time=slice(time[0], time[-1]))
        # reindex times (slice doesn't extend the range).
        # NOTE(review): with method='nearest' and no tolerance, labels outside the
        # database's own time range take the nearest edge value rather than NaN — confirm intent.
        proxy_db = proxy_db.reindex({'time': time}, method='nearest')

        # proxy values (can also serve as a time mask when using PPE)
        pp_y = proxy_db[c.obs_var[ip][0]]
        if c.psm[ip] != 'linear':
            pp_r = proxy_db[c.obs_var[ip][1]]
        else:
            # take the linear regression error when working with a linear PSM
            pp_r = proxy_db[c.linear['error']]

        # find sites without values after reindexing (they break multi timescale)
        # and replace the errors of the others by the fixed proxy error.
        empty_sites = []
        for s in pp_y.site:
            if c.ppe['use'] == True:
                avail_times = pp_y.sel(site=s).time.values
            else:
                avail_times = pp_y.sel(site=s).dropna('time').time.values
            if len(avail_times) == 0:
                empty_sites.append(s.values.item())
            elif c.proxy_error is not None and c.proxy_error[ip] is not None:
                # no 2-d indexing for DataArrays, hence .loc with site and times
                pp_r.loc[dict(site=s, time=avail_times)] = np.ones(len(avail_times)) * c.proxy_error[ip]
        if empty_sites:
            # drop all empty sites at once instead of copying per site
            pp_y = pp_y.drop_sel(site=empty_sites)
            pp_r = pp_r.drop_sel(site=empty_sites)
            proxy_db = proxy_db.drop_sel(site=empty_sites)

        # replace site number by string with database prefix
        sites = [str(ip) + '.' + str(s.values.tolist()) for s in proxy_db['site']]
        pp_y['site'] = sites
        pp_r['site'] = sites
        pp_y_all.append(pp_y)
        pp_r_all.append(pp_r)
        # release the dataset early (notebooks slowed down otherwise)
        del proxy_db
    return pp_y_all, pp_r_all
def proxy_timeres(c, pp_y_all):
    """
    Compute one time resolution per proxy site.

    For database ip the resolution mode is c.time_scales[ip]:
      - 'min' / 'mean' / 'median': statistic of the year spacing between
        consecutive non-NaN values of the site,
      - a number (int, float or numpy integer): used directly.
    The raw resolution is then rounded up to the next available timescale in
    c.timescales (assumed sorted ascending, as done in config_check), and
    capped at the largest timescale. Sites with a single record get
    resolution 1. An unknown mode terminates via sys.exit.

    Parameters
    ----------
    c : namespace
        Configuration (uses c.timescales and c.time_scales).
    pp_y_all : list of xarray.DataArray
        Proxy values, one DataArray per proxy database.

    Returns
    -------
    list of xarray.DataArray
        One DataArray per database with one resolution value per site (the
        site coordinate is kept).
    """
    timescales = c.timescales
    timeres_list = []
    for ip, db in enumerate(pp_y_all):
        resols = xr.DataArray(np.zeros(len(db.site)), coords=dict(site=db.site))
        mode = c.time_scales[ip]
        for i, s in enumerate(db.site):
            years = db.sel(site=s).dropna('time').time.dt.year
            if len(years) == 1:
                print('only one record for site ', s.values, '. Giving it timescale 1')
                resols[i] = 1
                continue

            # distance (in years) between consecutive records
            year_vals = np.asarray(years)
            dist_right = np.abs(year_vals[1:] - year_vals[:-1])
            if mode == 'min':
                res = dist_right.min()
            elif mode == 'mean':
                res = dist_right.mean()
            elif mode == 'median':
                res = np.median(dist_right)
            elif isinstance(mode, (int, float, np.integer)):
                # number corresponds to prescribed resolution
                res = mode
            else:
                import sys
                sys.exit("Time resolution mode unknown. Exit.")

            # round up to the next available timescale, capped at the largest one
            if res > timescales[-1]:
                resols[i] = timescales[-1]
            else:
                resols[i] = timescales[bisect_left(timescales, res)]
        timeres_list.append(resols)
    return timeres_list
def resample_proxies(c, timeres_list, times_list, pp_y_all):
    """
    Resample every proxy record onto the target time axis of its resolution.

    For each proxy (make_equidistant_target + mask_the_gap_alt):
      - upsample to yearly resolution, interpolating with c.resample_mode,
      - lowpass filter (only when the target resolution is >= 4 years),
      - resample to the target time axis times_list[k] of timescale c.timescales[k],
      - mask gaps larger than c.mask * resolution found in the original record.
    With c.reuse == True a record is additionally resampled to every coarser
    timescale.

    Parameters
    ----------
    c : namespace
        Configuration (mask, resample_mode, timescales, reuse).
    timeres_list : list of xarray.DataArray
        Per-database resolution of each site (from proxy_timeres).
    times_list : list
        Target time axis for each entry of c.timescales (same order).
    pp_y_all : list of xarray.DataArray
        Proxy values, one DataArray per database.

    Returns
    -------
    list of xarray.DataArray
        One (time, site) DataArray per timescale. Site names keep their
        database prefix; attrs['DB_members'] holds the number of proxies per
        database. np.stack raises ValueError when a timescale received no proxy.
    """
    mask_ = c.mask  # masking tolerance factor (mask_ * time_res is max. gap size)
    mode = c.resample_mode
    timescales = np.array(c.timescales)  # make sure it's really a numpy array

    def _resample_to(data, res):
        # resample one record to the target axis of resolution `res` and mask its gaps
        target_time = times_list[np.flatnonzero(timescales == res).item()]
        # the lowpass filter is only used for resolutions >= 4 years
        resampled = make_equidistant_target(
            data, target_time, target_res=res, method_interpol=mode, filt=res >= 4, min_ts=1
        )
        return mask_the_gap_alt(resampled, data, time_res=res, tol=mask_)

    # collected series and site names for each timescale
    collected = {str(scale): dict(ts=[], sites=[]) for scale in timescales}

    print('resampling of proxies in each database')
    for i, db in enumerate(pp_y_all):
        timeres_vals = timeres_list[i].values
        for ii, s in enumerate(tqdm.tqdm(db.site)):
            data = db.sel(site=s)
            # keep site as string, else 0.1 would equal 0.10
            site_name = s.values.tolist()
            res = int(timeres_vals[ii])
            idx = np.flatnonzero(timescales == res).item()
            if c.reuse == True:
                # own timescale plus all coarser ones to the right
                target_scales = [int(t) for t in timescales[idx:]]
            else:
                target_scales = [res]
            for scale in target_scales:
                resampled = _resample_to(data, scale)
                collected[str(scale)]['ts'].append(resampled.values)
                collected[str(scale)]['sites'].append(site_name)

    # bring everything together, one DataArray per timescale
    final_list = []
    for key, dic in collected.items():
        vals = np.stack(dic['ts'])
        sites = dic['sites']
        target_time = times_list[np.flatnonzero(timescales == int(key)).item()]
        data_array = xr.DataArray(vals, coords=dict(site=sites, time=target_time))
        # number of proxies per database: the integer part of the site name is the db index
        db_index = [int(float(site)) for site in sites]
        data_array.attrs['DB_members'] = np.unique(db_index, return_counts=True)[1]
        final_list.append(data_array.transpose('time', 'site'))
    return final_list
def make_equidistant_target(data, target_time, target_res, method_interpol='nearest', filt=True, min_ts=1):
    """
    Resample a proxy time series equidistantly onto a target time axis.

    Takes a proxy timeseries "data" (fully resolved, with nans in between if no
    value is available) and resamples it to the (equidistant) target timeseries
    "target_time" (DataArray of cftime objects, the .time.dt.year accessor is
    used), which has the resolution "target_res" (consistency with target_time
    is not checked). We usually set target_res to the median resolution.
    The procedure is adapted from the Paleospec R-package:
    https://github.com/EarthSystemDiagnostics/paleospec/blob/master/R/MakeEquidistant.R.
    The time resolution is based on yearly data; other resolutions (monthly)
    would require adapting the filtering part.

    Steps:
    0. Duplicate the first non-nan data point at the preceding target time if
       the first year is not on the target grid.
    1. Resample and interpolate the series to 'min_ts'-year resolution with
       `method_interpol` (yearly makes sense in our case).
    2. Low-pass filter the upsampled series (Butterworth filter as in the
       R-package; filtfilt is used to avoid a time lag). Only when filt is True;
       skipped if the series is too short for the filter.
    3. Resample to the target resolution and reindex to target_time.

    Comments:
    1. Some proxy records have huge jumps without data in between. The
       resampled values there are not meaningful and need to be masked separately.
    2. Use xarray > v2022.06.0 for fast resampling.

    Example:
    Given measurements at times [4,9,14,19,24], treated as mean values for the
    time range centered on these times, we want to resample them for the times
    [0,5,10,15,20,25]. These target labels are the left edge of a time block
    (in the DA we effectively reconstruct the mean of the years
    [[0,1,2,3,4],[5,6,7,8,9],[10,11,12,13,14],[15,16,17,18,19],[20,21,22,23,24]]).
    Therefore the final down-sampling uses closed='left'.

    Raises
    ------
    Exception
        Whatever the upsampling interpolation raises for records with more
        than one value (previously this dropped into the debugger).
    """
    from scipy.signal import butter, filtfilt

    # drop nan entries: interpolation does not work with them
    data = data.dropna('time')
    vals = data.values
    time = data.time.values
    time_years = data.time.dt.year.values

    # 0. if the first year is not on the target grid, prepend the preceding
    # target time with a copy of the first value (modulo takes a shifted start into account)
    first_year = time_years[0]
    target_years = target_time.time.dt.year.values
    start = first_year - first_year % target_res + target_years[0] % target_res
    if start != first_year:
        idx = np.searchsorted(target_years, first_year, side="left")
        # idx == 0: no target time precedes the record; target_time[-1] would
        # wrap around to the END of the axis, so nothing is prepended.
        if idx > 0:
            time = np.insert(time, 0, target_time[idx - 1].values)
            vals = np.insert(vals, 0, vals[0])
    vals_new = xr.DataArray(data=vals, coords=dict(time=time))

    # 1. upsampling and interpolation
    try:
        upsampled = vals_new.resample(time=str(min_ts) + 'YS').interpolate(method_interpol)
    except Exception:
        if len(vals_new.time) != 1:
            raise
        # only one value: nothing to interpolate, its time already fits (see step 0)
        upsampled = vals_new

    # 2. low-pass filter (avoid aliasing)
    def _lowpass(series, cutoff, fs, order=6, kf=1.2):
        # kf scales the cutoff: 1 = Nyquist, 1.2 x Nyquist trades variance loss
        # against aliasing. filtfilt (forward+backward) introduces no time lag.
        nyq = 0.5 * fs
        b, a = butter(order, cutoff / nyq * kf, btype='low', analog=False)
        return filtfilt(b, a, series)

    if filt == True:
        try:
            up_filt = _lowpass(upsampled, cutoff=1 / target_res, fs=1, order=4)
        except ValueError:
            # series too short for the filter (e.g. short reconstruction range): keep it unfiltered
            up_filt = upsampled
    else:
        up_filt = upsampled

    # 3. resample to target resolution ('YS': year start is very important)
    up_filt = xr.DataArray(up_filt, coords=dict(time=upsampled['time']))
    resampled = up_filt.resample(time=str(target_res) + 'YS', closed='left').mean('time')
    # reindex to the full target axis (non-existing values filled with nan)
    return resampled.reindex(time=target_time)
def mask_the_gap_alt(resampled_ts, original_ts, time_res, tol):
    """
    Mask, in a resampled series, the gaps of the original proxy record.

    A gap is a distance between consecutive non-NaN years of original_ts larger
    than tol * time_res. For each gap, the resampled values with a time label
    strictly after the gap start are set to NaN, up to but excluding the last
    label <= the gap end (that block is influenced by the measurement at the
    gap end, so it is kept).

    Input:
    resampled_ts: equidistant time series (time labels assumed sorted ascending)
    original_ts: original proxy time series (NaN between measurements)
    time_res: resolution of resampled_ts
    tol: gap tolerance factor, multiplied with time_res

    Returns a copy of resampled_ts with the gaps masked. Gaps lying entirely
    before or after the target axis are ignored.
    """
    resampled_ts = resampled_ts.copy()
    max_gap = tol * time_res

    # screen the original time series for jumps
    original_years = original_ts.dropna('time').time.dt.year.values
    gaps = np.abs(original_years[1:] - original_years[:-1])
    args = np.flatnonzero(gaps > max_gap)  # left end of each gap
    starts = original_years[args]
    ends = original_years[args + 1]

    target_time_ts = resampled_ts['time']
    target_years_ts = target_time_ts.dt.year.values
    n_target = len(target_years_ts)
    # searchsorted(side='right') == bisect_right, vectorised over all gaps
    start_idx = np.searchsorted(target_years_ts, starts, side='right')
    # -2: last label <= end is kept as well (label slices include their end)
    end_idx = np.searchsorted(target_years_ts, ends, side='right') - 2
    for lo, hi in zip(start_idx, end_idx):
        # nothing to mask; also avoids negative indices wrapping to the end
        if lo >= n_target or hi < lo:
            continue
        resampled_ts.loc[dict(time=slice(target_time_ts[lo], target_time_ts[hi]))] = np.nan
    return resampled_ts
def noresample_proxies(c, timeres_list, times_list, pp_y_all):
    """
    Multi-timescale preparation without the full proxy resampling scheme.

    Each record is only block-averaged (xarray resample with closed='left') to
    its own resolution, and additionally to every coarser timescale when
    c.reuse == True. The proxies are collected into one DataArray per
    timescale on the axis times_list[k] of c.timescales[k].

    Returns
    -------
    list of xarray.DataArray
        One (time, site) DataArray per timescale. Site names keep their
        database prefix ('0.', '1.', ...); attrs['DB_members'] holds the number
        of proxies per database. np.stack raises ValueError when a timescale
        received no proxy.
    """
    timescales = np.array(c.timescales)
    # collected series and site names for each timescale
    collected = {str(scale): dict(ts=[], sites=[]) for scale in timescales}

    for i, db in enumerate(pp_y_all):
        timeres_vals = timeres_list[i].values
        for ii, s in enumerate(db.site):
            data = db.sel(site=s)
            # keep site as string, else 0.1 would equal 0.10
            site_name = s.values.tolist()
            res = int(timeres_vals[ii])
            idx = np.flatnonzero(timescales == res).item()
            if c.reuse == True:
                # own timescale plus all coarser ones to the right
                target_scales = [int(t) for t in timescales[idx:]]
            else:
                target_scales = [res]
            for scale in target_scales:
                resampled = data.resample(time=str(scale) + 'YS', closed='left').mean()
                collected[str(scale)]['ts'].append(resampled.values)
                collected[str(scale)]['sites'].append(site_name)

    # bring everything together, one DataArray per timescale
    final_list = []
    for key, dic in collected.items():
        vals = np.stack(dic['ts'])
        sites = dic['sites']
        target_time = times_list[np.flatnonzero(timescales == int(key)).item()]
        data_array = xr.DataArray(vals, coords=dict(site=sites, time=target_time))
        # number of proxies per database: the integer part of the site name is the db index
        db_index = [int(float(site)) for site in sites]
        data_array.attrs['DB_members'] = np.unique(db_index, return_counts=True)[1]
        final_list.append(data_array.transpose('time', 'site'))
    return final_list
def psm_apply(c,prior,prior_raw, pp_y_all,other_model=False):
"""
Takes prior and config.
Psm weighted yearly average requires monthly data (prior_raw)
Converts prior values into HXfull (full time-series at proxy locations, corresponding to prior)
PP_y_all needed for proxy site - name and metadata
Options:
-interpolation to nearest grid point (None vs distance weighted)
- No-psm: Nearest/distance weighted variable of type XY
- Speleo:
- weighting: inf, prec, None
- height correction
- fractionation: True/False
- alphas (Tremaine for calcite, Grossman for Aragonite)
- filter: False/True/float (transit time in cave) #standard convolution time is 2.5 years
- Icecore:
- weighting: precipitation, None
- height: True/False (given by orography file)
other_model: this is set to True for the noise_bf_filt option (that is only relevant for d18O)
this option is only used in pseudoproxy_generator function (source external part)
These options are given in c.psm as a dict.
"""
#List where we append the model proxy estimates to for each database
HXfull_all=[]
#eventually pp_r will be exchanged
pp_r=None
pp_r_list=[]
#loop over psms (which corresponds do looping over the proxy dbs, because there is one psm for each proxy db)
for i,psm in enumerate(c.psm):
#extract values
#lats=pp_y_lat[i]
#lons=pp_y_lon[i]
lats=pp_y_all[i].lat.values
lons=pp_y_all[i].lon.values
proxies=pp_y_all[i]
if psm==None:
#extract needed variable from prior
var=c.var_psm[i]
prior_var=prior[var]
#extrapolate according to positions of proxies
HXfull=psm_pseudoproxy.obs_from_model(prior_var,lats,lons,interpol=c.interpol)
elif psm=='linear':
HXfull=psm_pseudoproxy.linear_psm(c, prior, pp_y_all[i])
elif psm=='speleo':
#weighting
print('USING SPELEO PSM')
if c.speleo['weighting']=='inf':
print('>>>>>>>>>>>>>GETTING MONTHLY d18O Data')
d18=psm_pseudoproxy.infilt_weighting(prior_raw['d18O'],prior_raw['prec'],prior_raw['evap'],slice_unfinished=True,check_nan=c.check_nan)
elif c.speleo['weighting']=='prec':
print('>>>>>>>>>>>>>GETTING MONTHLY d18O Data')
d18=psm_pseudoproxy.prec_weighting(prior_raw['d18O'],prior_raw['prec'],slice_unfinished=True,check_nan=c.check_nan)
else:
d18=prior['d18O']
#pdb conversion at the beginning, always!
d18=psm_pseudoproxy.pdb_conversion(d18)
d18=psm_pseudoproxy.obs_from_model(d18,lat=lats,lon=lons,interpol=c.interpol)
#replace site names, else missing
d18['site']=proxies['site']
#add noise! (other model option prevent that this is applied to HXf)
if c.ppe['use']==True:
if c.ppe['noise_bf_filt']==True and other_model==True:
print('Noise added to d18O before filtering')
d18,pp_r=psm_pseudoproxy.pseudoproxies(d18, SNR=c.ppe['SNR'][0],noisetype=c.ppe['noise_type'],seed=c.seed)
tsurf=prior['tsurf']
tsurf=psm_pseudoproxy.obs_from_model(tsurf,lat=lats,lon=lons,interpol=c.interpol)
tsurf['site']=proxies['site']
#height correction
if c.speleo['height']==True:
print('>>>>>>>>>>>>>APPLYING HEIGHT CORRECTION')
oro=xr.open_dataset(c.oro)['oro']
#my obs_from model function is a sel that works with single lats and lons
#obs from model interpolation seems to be important for some speleos in order to not completly
#be off (especially Echam and CESM Model)
oro=psm_pseudoproxy.obs_from_model(oro,lat=lats,lon=lons,interpol='dw')
elev=proxies['elev']
z=(elev-oro)
#Tsurf: -0.65 https://en.wikipedia.org/wiki/Lapse_rate
#d18O take global value: global average -0.28: https://www.ajsonline.org/content/ajs/301/1/1.full.pdf
#most height sensitive speleos are in the himalay, where the lapse rate is smaller
d18= d18 + -0.28/100*z
tsurf = tsurf + -0.5/100*z
#karst filter
if c.speleo['filter']==True:
print('>>>>>>>>>>>>>APPLYING KARST FILTER')
#following the PRYSM PSM by S. Dee (2015). Transit time 2.5 as in Bühler 2021
#for individual transit time it would be easiest to adapt the proxy_db DataSet with some additional coordinate as metadata
#add noise before the filtering process for pseudoproxies, applied to d18O and temp
# WARNING: only works with a uniform SNR for all databases, and does not make sense for multi-timescale!
#hacky solution to make experiments for seeing how much information is lost due to filtering
#the noise is recalculated afterwards to not make output of this function too clumsy (can be done when c.seed is set)
#noise not added to tsurf (but could be changed)
#tsurf,_=psm_pseudoproxy.pseudoproxies(temp, SNR=c.ppe['SNR'][0],noisetype=c.ppe['noise_type'],seed=c.seed)
tau0=c.speleo['t_time']
#set timeseries
tau=np.arange(len(d18.time))
#Green's function
#well-mixed model
g = (1./tau0) * np.exp(-tau/tau0)
#normalize tau (as we have yearly spaced values, we just sum up all values)
g=g/np.sum(g)
for s in d18.site:
#convolve d18O with g
#subtract mean
vals=d18.sel(site=s)
mean=vals.mean('time').values
#get time axis number
ax=vals.get_axis_num('time')
#convolve padding first/last value (no-problem as g decreases very quickly anyqay)
conv=np.apply_along_axis(lambda m: np.convolve(m, g), axis=ax, arr=(vals-mean).values)[:len(vals)]
#exchange values in initial array
d18.loc[dict(site=s)]=conv+mean
#fractionation (separated treatment of aragonite and calcite)
if c.speleo['fractionation']==True:
print('>>>>>>>>>>>>>APPLYING FRACTIONATION')
#distinguish between aragonite/non-aragonite sites, assume non-aragonite is calcite
arag_sites=proxies.where(proxies['mineralogy']=='aragonite',drop=True).site
calc_sites=proxies.where(proxies['mineralogy']!='aragonite',drop=True).site
#applying the mean tsurf, not tsurf itself, as else the covariance pattern is seriously reduced!
if c.speleo['fractionation_temp']=='mean':
print('use mean temperature')
#d18_calc=psm_pseudoproxy.frac(d18.sel(site=calc_sites),tsurf.sel(site=calc_sites).mean('time'),psm_pseudoproxy.pdb_coplen,psm_pseudoproxy.alpha_calc_trem)
#d18_arag=psm_pseudoproxy.frac(d18.sel(site=arag_sites),tsurf.sel(site=arag_sites).mean('time'),psm_pseudoproxy.pdb_coplen,psm_pseudoproxy.alpha_arag_grossman)
d18_calc=psm_pseudoproxy.frac(d18.sel(site=calc_sites),tsurf.sel(site=calc_sites).mean('time'),psm_pseudoproxy.alpha_calc_trem)
d18_arag=psm_pseudoproxy.frac(d18.sel(site=arag_sites),tsurf.sel(site=arag_sites).mean('time'),psm_pseudoproxy.alpha_arag_grossman)
elif c.speleo['fractionation_temp']=='regular':
print('use time-varying temperature')
#d18_calc=psm_pseudoproxy.frac(d18.sel(site=calc_sites),tsurf.sel(site=calc_sites),psm_pseudoproxy.pdb_coplen,psm_pseudoproxy.alpha_calc_trem)
#d18_arag=psm_pseudoproxy.frac(d18.sel(site=arag_sites),tsurf.sel(site=arag_sites),psm_pseudoproxy.pdb_coplen,psm_pseudoproxy.alpha_arag_grossman)
d18_calc=psm_pseudoproxy.frac(d18.sel(site=calc_sites),tsurf.sel(site=calc_sites),psm_pseudoproxy.alpha_calc_trem)
d18_arag=psm_pseudoproxy.frac(d18.sel(site=arag_sites),tsurf.sel(site=arag_sites),psm_pseudoproxy.alpha_arag_grossman)
else:
print('unknown fractionation_temp-mode')
break
d18.loc[dict(site=calc_sites)]=d18_calc
d18.loc[dict(site=arag_sites)]=d18_arag
#if no fractionation deconvert pdb! (Else the value is really off)
else:
d18=psm_pseudoproxy.pdb_conversion_r(d18)
HXfull=d18
if pp_r is not None:
pp_r_list.append(pp_r)
elif psm=='icecore':
print('USING ICECORE PSM')
#weighting
if c.icecore['weighting']=='prec':
print('>>>>>>>>>>>>>APPLYING PREC WEIGHTING')
d18=psm_pseudoproxy.prec_weighting(prior_raw['d18O'],prior_raw['prec'],slice_unfinished=True,check_nan=c.check_nan)
else:
d18=prior['d18O']
d18=psm_pseudoproxy.obs_from_model(d18,lat=lats,lon=lons,interpol=c.interpol)
#add noise! (other model option prevent that this is applied to HXf)
if c.ppe['use']==True:
if c.ppe['noise_bf_filt']==True and other_model==True:
print('Noise added to d18O before filtering')
d18,pp_r=psm_pseudoproxy.pseudoproxies(d18, SNR=c.ppe['SNR'][0],noisetype=c.ppe['noise_type'],seed=c.seed)
#height correction
if c.icecore['height']==True:
print('>>>>>>>>>>>>>APPLYING HEIGHT CORRECTION')
oro=xr.open_dataset(c.oro)['oro']
oro=psm_pseudoproxy.obs_from_model(oro,lat=lats,lon=lons,interpol=c.interpol)
#oro.sel(lat=lats,lon=lons,method='nearest')
elev=proxies['elev']
z=(elev-oro)
#Tsurf: -0.65 https://en.wikipedia.org/wiki/Lapse_rate
#d18O take global value: global average -0.28: https://www.ajsonline.org/content/ajs/301/1/1.full.pdf
#d18= d18 + proxies['lapse_rate']/100*z
d18= d18 + -0.28/100*z
#Diffusion and compactation
if c.icecore['filter']==True:
print('>>>>>>>>>>>>>APPLYING PRYSM ICECORE FILTER')
prec_site=psm_pseudoproxy.obs_from_model(prior['prec'],lat=lats,lon=lons,interpol=c.interpol)
prec_site['site']=d18.site
tsurf_site=psm_pseudoproxy.obs_from_model(prior['tsurf'],lat=lats,lon=lons,interpol=c.interpol)
tsurf_site['site']=d18.site
#add noise before the filtering process for pseudoproxies, only acts on the d18O (see above)
for s in tqdm.tqdm(d18.site):
#only nproc=1 workso
d18.loc[dict(site=s)]=psm_pseudoproxy.ice_archive(d18.sel(site=s),prec_site.sel(site=s),tsurf_site.sel(site=s),xr.DataArray(np.array([101.325]*len(prior.time))),nproc=1)
HXfull=d18
if pp_r is not None:
pp_r_list.append(pp_r)
else:
raise Exception("Given psm type unknown. check 'psm' in config dictionary.")
#add site coordinate
HXfull['site']=proxies['site']
HXfull_all.append(HXfull)
#pseudoproxy error is eventually returned (else its just none and nothing happens)
if other_model==True:
return HXfull_all,pp_r_list
else:
return HXfull_all
def resample_wrapper(c,pp_y_all,pp_r_all):
#Suppres warnings. Bad practice, but warnings in resampling part are annoying (some pandas stuff)
warnings.simplefilter("ignore",category=DeprecationWarning)
warnings.simplefilter("ignore",category=FutureWarning)
#time arrays for each resolution
length=int(c.proxy_time[1])-int(c.proxy_time[0])
#workaround if we want to reconstruct only x-year means (speleo experiments)
if 1 not in c.timescales:
times_list=[xr.DataArray(xr.cftime_range(start=c.proxy_time[0],periods=(length//i+1),freq=str(i)+'YS',calendar='365_day'),dims='time') for i in np.concatenate([[1],c.timescales])]
else:
times_list=[xr.DataArray(xr.cftime_range(start=c.proxy_time[0],periods=(length//i+1),freq=str(i)+'YS',calendar='365_day'),dims='time') for i in c.timescales]
#adapt times_list (cut end in case it doesn't fit perfectly with largest block size)
#I needed an (eventually) different times_list for resampling the proxies
#this could definitely be nicer
new_times_list=[]
time_sc=c.timescales
if 1 not in time_sc:
time_sc=np.insert(time_sc,0,1)
for i,t in enumerate(times_list):
ts=time_sc[i]
end=str(((int(c.time[1])-int(c.time[0]))//ts)*ts+int(c.time[0]))
if end>c.proxy_time[1]:
end=str(((int(c.proxy_time[1])-int(c.time[0]))//ts)*ts+int(c.time[0]))
new_times_list.append(t.sel(time=slice(c.time[0],end)))
times_list=new_times_list
#drop where there are no values in the final time range (only without ppe)
first_time=times_list[0][0]
last_time=times_list[0][-1]
if c.ppe['use']==False:
pp_y_all_new=[]
pp_r_all_new=[]
for idx,pp in enumerate(pp_y_all):
pp_l=pp.sel(time=slice(first_time,last_time))
pp2=pp.copy(deep=True)
pp_r2=pp_r_all[idx].copy(deep=True)
for s in (pp.site):
avail_times=pp_l.sel(site=s).dropna('time').time.values
if len(avail_times)==0:
pp2=pp2.drop_sel(site=s.values)
pp_r2=pp_r2.drop_sel(site=s.values)