Skip to content

Commit

Permalink
Added codes for filtering and examples on Nstage FWI
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Mar 28, 2024
1 parent 04b12c3 commit 222983e
Show file tree
Hide file tree
Showing 8 changed files with 6,204 additions and 3,978 deletions.
117 changes: 117 additions & 0 deletions devitofwi/preproc/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import butter, sosfiltfilt, correlate, freqs, hilbert


def create_filter(nfilt, fmin, fmax, dt, plotflag=False):
if fmin is None:
b, a = butter(nfilt, fmax, 'low', analog=True)
sos = butter(nfilt, fmax, 'low', fs=1 / dt, output='sos')
else:
b, a = butter(nfilt, [fmin, fmax], 'bandpass', analog=True)
sos = butter(nfilt, [fmin, fmax], 'bandpass', fs=1 / dt, output='sos')

if plotflag:
w, h = freqs(b, a)
plt.semilogx(w, 20 * np.log10(abs(h)), 'k', lw=2)
plt.title('Butterworth filter frequency response')
plt.xlabel('Frequency [radians / second]')
plt.ylabel('Amplitude [dB]')
plt.margins(0, 0.1)
plt.grid(which='both', axis='both')
plt.axvline(fmax, color='green') # cutoff frequency

return b, a, sos


def apply_filter(sos, inp):
filtered = sosfiltfilt(sos, inp, axis=-1)
return filtered


def filter_data(nfilt, fmin, fmax, dt, inp, plotflag=False):
"""Filter data
Apply Butterworth band-pass filter to data
Parameters
----------
nfilt : :obj:`int`
Size of filter
fmin : :obj:`float`
Minimum frequency
fmax : :obj:`float`
Maximum frequency
dt : :obj:`float`
Time sampling
inp : :obj:`numpy.ndarray`
Data of size `nx x nt`
Returns
-------
b : :obj:`numpy.ndarray`
Filter numerator coefficients
b : :obj:`numpy.ndarray`
Filter denominator coefficients
sos : :obj:`numpy.ndarray`
Filter sos
filtered : :obj:`numpy.ndarray`
Filtered data of size `nx x nt`
"""
b, a, sos = create_filter(nfilt, fmin, fmax, dt, plotflag=plotflag)
filtered = apply_filter(sos, inp)

return b, a, sos, filtered


class Filter():
"""Filtering
Define a sequence of filters to apply to a dataset/wavelet
Parameters
----------
freqs : :obj:`list`
Minimum frequencies
nfilt : :obj:`int`
Size of filters
dt : :obj:`float`
Time sampling
p
"""
def __init__(self, freqs, nfilts, dt, plotflag=False):
self.freqs = freqs
self.nfilts = nfilts
self.dt = dt
self.plotflag = plotflag
self.filters = self._create_filters()

def _create_filters(self):
filters = []

for freq, nfilt in zip(self.freqs, self.nfilts):
filters.append(create_filter(nfilt, None, freq, self.dt, plotflag=self.plotflag)[-1])
return filters

def apply_filter(self, inp, ifilt=0):
return apply_filter(self.filters[ifilt], inp)

def find_optimal_t0(self, inp, pad=400, thresh=1e-2):
"""Find optimal padding
Identify optimal padding to avoid any filtered signal to become acausal. To be used when designing the filters
to choose how much the wavelet and observed data must be padded
"""
inppad = np.pad(inp, (pad, pad))
itmax = np.argmax(np.abs(inppad))
it0 = np.where(np.abs(inppad[:itmax]) < thresh * inppad[itmax])[0][-1]
for ifilt in range(len(self.filters)):
inpfilt = apply_filter(self.filters[ifilt], inppad)
inpfiltenv = np.abs(hilbert(inpfilt))
it0filt = np.where(np.abs(inpfiltenv[:itmax]) < thresh * inpfiltenv[itmax])[0][-1]
it0 = min(it0, it0filt)
optimalpad = pad - it0
return optimalpad
71 changes: 71 additions & 0 deletions devitofwi/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import matplotlib.pyplot as plt

from examples.seismic.utils import sources, PointSource


class CustomSource(PointSource):

"""
Abstract base class for symbolic objects that encapsulates a set of
sources with a user defined source signal wavelet.
Parameters
----------
name : str
Name for the resulting symbol.
grid : Grid
The computational domain.
time_range : TimeAxis
TimeAxis(start, step, num) object.
wav : numpy.ndarray
Wavelet
"""

__rkwargs__ = PointSource.__rkwargs__ + ['wav']

@classmethod
def __args_setup__(cls, *args, **kwargs):
kwargs.setdefault('npoint', 1)

return super().__args_setup__(*args, **kwargs)

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

self.wav = kwargs.get('wav')

if not self.alias:
for p in range(kwargs['npoint']):
self.data[:, p] = self.wavelet

@property
def wavelet(self):
"""
Return user-provided wavelet
"""
return self.wav

def show(self, idx=0, wavelet=None):
"""
Plot the wavelet of the specified source.
Parameters
----------
idx : int
Index of the source point for which to plot wavelet.
wavelet : ndarray or callable
Prescribed wavelet instead of one from this symbol.
"""
wavelet = wavelet or self.data[:, idx]
plt.figure()
plt.plot(self.time_values, wavelet)
plt.xlabel('Time (ms)')
plt.ylabel('Amplitude')
plt.tick_params()
plt.show()


sources['CustomSource'] = CustomSource

14 changes: 14 additions & 0 deletions devitofwi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from os import listdir, path
from shutil import rmtree
from tempfile import gettempdir


def clear_devito_cache():
tempdir = gettempdir()
for i in listdir(tempdir):
if i.startswith('devito-'):
try:
target = path.join(tempdir, i)
rmtree(target)
except:
pass
55 changes: 36 additions & 19 deletions devitofwi/waveengine/acoustic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray, SamplingLike
from tqdm.notebook import tqdm

devito_message = deps.devito_import("the twoway module")

if devito_message is None:
from devito import Function
from examples.seismic import AcquisitionGeometry, Model, Receiver
from examples.seismic.acoustic import AcousticWaveSolver

from devito import Function
from examples.seismic import AcquisitionGeometry, Model, Receiver
from examples.seismic.acoustic import AcousticWaveSolver
from devitofwi.source import CustomSource

#class AcousticWave2D(LinearOperator):
class AcousticWave2D():
Expand Down Expand Up @@ -49,14 +47,16 @@ class AcousticWave2D():
(use``None`` if the data is already available)
vpinit : :obj:`numpy.ndarray`, optional
Initial velocity model in m/s as starting guess for inversion
src_type : :obj:`str`, optional
Source type
space_order : :obj:`int`, optional
Spatial ordering of FD stencil
nbl : :obj:`int`, optional
Number ordering of samples in absorbing boundaries
src_type : :obj:`str`, optional
Source type
f0 : :obj:`float`, optional
Source peak frequency (Hz)
wav : :obj:`numpy.ndarray`, optional
Wavelet (if provided ``src_type`` and ``f0`` will be ignored
checkpointing : :obj:`bool`, optional
Use checkpointing (``True``) or not (``False``). Note that
using checkpointing is needed when dealing with large models
Expand All @@ -80,18 +80,16 @@ def __init__(
tn: int,
vp: Optional[NDArray] = None,
vpinit: Optional[NDArray] = None,
src_type: Optional[str] = "Ricker",
space_order: Optional[int] = 4,
nbl: Optional[int] = 20,
src_type: Optional[str] = "Ricker",
f0: Optional[float] = 20.0,
wav: Optional[NDArray] = None,
checkpointing: Optional[bool] = False,
loss: Optional[Type] = None,
dtype: Optional[DTypeLike] = "float32",
) -> None:
if devito_message is not None:
raise NotImplementedError(devito_message)

# checks
# velocity checks to ensure either vp or vint are provided
if vp is None and vpinit is None:
raise ValueError("Either vp or vpinit must be provided...")
if vpinit is not None and loss is None:
Expand All @@ -101,6 +99,7 @@ def __init__(
self.space_order = space_order
self.nbl = nbl
self.checkpointing = checkpointing
self.wav = wav

# inversion parameters
self.loss = loss
Expand Down Expand Up @@ -248,15 +247,24 @@ def _mod_oneshot(self, isrc: int, dt: float = None) -> NDArray:
# update source location in geometry
geometry = self.geometry1shot
geometry.src_positions[0, :] = self.geometry.src_positions[isrc, :]


# re-create source (if wav is not None)
if self.wav is None:
src = geometry.src
else:
src = CustomSource(name='src', grid=self.model.grid,
wav=self.wav, npoint=1,
time_range=geometry.time_axis)
src.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]

# data object
d = Receiver(name='data', grid=self.model.grid,
time_range=geometry.time_axis,
coordinates=geometry.rec_positions)
# solve
solver = AcousticWaveSolver(self.model, geometry,
space_order=self.space_order)
_, _, _ = solver.forward(vp=self.model.vp, rec=d)
_, _, _ = solver.forward(vp=self.model.vp, rec=d, src=src)

# resample
if dt is None:
Expand All @@ -283,14 +291,14 @@ def _mod_allshots(self, dt=None) -> NDArray:
dtot = np.array(dtot).reshape(nsrc, d.shape[0], d.shape[1])
return dtot

def _loss_grad_oneshot(self, vp, geometry, solver, d_obs, d_syn, adjsrc, grad, dobs, isrc,
def _loss_grad_oneshot(self, vp, geometry, src, solver, d_obs, d_syn, adjsrc, grad, dobs, isrc,
computeloss=True, computegrad=True) -> Tuple[float, NDArray]:

# Generate synthetic data from true model
#d_obs.data[:] = dobs

# Compute smooth data and full forward wavefield u0
_, u0, _ = solver.forward(vp=vp, save=True, rec=d_syn)
_, u0, _ = solver.forward(vp=vp, save=True, rec=d_syn, src=src)

# Compute loss
if computeloss:
Expand Down Expand Up @@ -330,7 +338,15 @@ def _loss_grad(self, vp, dobs, isrcs=None, mask=None, computeloss=True, computeg
"""
# geometry for single source
geometry = self.geometry1shot


# re-create source (if wav is not None)
if self.wav is None:
src = geometry.src
else:
src = CustomSource(name='src', grid=self.model.grid,
wav=self.wav, npoint=1,
time_range=geometry.time_axis)

# solver
solver = AcousticWaveSolver(self.model, geometry,
space_order=self.space_order)
Expand All @@ -354,7 +370,8 @@ def _loss_grad(self, vp, dobs, isrcs=None, mask=None, computeloss=True, computeg
for isrc in tqdm(isrcs):
# update source location in geometry
geometry.src_positions[0, :] = self.geometry.src_positions[isrc, :]
lossgrad = self._loss_grad_oneshot(vp, geometry, solver, d_obs, d_syn, adjsrc, grad, dobs[isrc], isrc)
src.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]
lossgrad = self._loss_grad_oneshot(vp, geometry, src, solver, d_obs, d_syn, adjsrc, grad, dobs[isrc], isrc)
if computeloss and computegrad:
loss_isrc, grad = lossgrad
loss += loss_isrc
Expand Down
1,011 changes: 541 additions & 470 deletions notebooks/acoustic/AcousticVel_L2_1stage.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 222983e

Please sign in to comment.