Skip to content

Commit

Permalink
reduce wavefield buffer to minimal
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 15, 2024
1 parent 11485a4 commit 2158f99
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 16 deletions.
24 changes: 15 additions & 9 deletions src/pysource/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from devito import (TimeFunction, ConditionalDimension, Function,
DefaultDimension, Dimension, VectorTimeFunction,
TensorTimeFunction)
TensorTimeFunction, Buffer)
from devito.builtins import initialize_function
from devito.tools import as_tuple

Expand All @@ -14,7 +14,8 @@
from utils import compression_mode


def wavefield(model, space_order, save=False, nt=None, fw=True, name='', t_sub=1):
def wavefield(model, space_order, save=False, nt=None, fw=True, name='', t_sub=1,
tfull=False):
"""
Create the wavefield for the wave equation
Expand All @@ -33,24 +34,28 @@ def wavefield(model, space_order, save=False, nt=None, fw=True, name='', t_sub=1
Forward or backward (for naming)
name: string
Custom name attached to default (u+name)
tfull: Bool
Whether need full buffer for e.g. second time derivative
"""
name = "u"+name if fw else "v"+name
save = False if t_sub > 1 else save
nsave = Buffer(3 if tfull else 2) if not save else nt

if model.is_tti:
u = TimeFunction(name="%s1" % name, grid=model.grid, time_order=2,
space_order=space_order, save=None if not save else nt)
space_order=space_order, save=nsave)
v = TimeFunction(name="%s2" % name, grid=model.grid, time_order=2,
space_order=space_order, save=None if not save else nt)
space_order=space_order, save=nsave)
return (u, v)
elif model.is_elastic:
v = VectorTimeFunction(name="v", grid=model.grid, time_order=1,
space_order=space_order, save=None)
space_order=space_order, save=Buffer(1))
tau = TensorTimeFunction(name="tau", grid=model.grid, time_order=1,
space_order=space_order, save=None)
space_order=space_order, save=Buffer(1))
return (v, tau)
else:
return TimeFunction(name=name, grid=model.grid, time_order=2,
space_order=space_order, save=None if not save else nt)
space_order=space_order, save=nsave)


def forward_wavefield(model, space_order, save=True, nt=10, dft=False, t_sub=1, fw=True):
Expand Down Expand Up @@ -112,7 +117,7 @@ def memory_field(p):
Forward wavefield
"""
return TimeFunction(name='r%s' % p.name, grid=p.grid, time_order=2,
space_order=p.space_order, save=None)
space_order=p.space_order, save=Buffer(2))


def wavefield_subsampled(model, u, nt, t_sub, space_order=8):
Expand Down Expand Up @@ -171,7 +176,8 @@ def lr_src_fields(model, weight, wavelet, empty_w=False, rec=False):
time = model.grid.time_dim
nt = wavelet.shape[0]
wn = 'rec' if rec else 'src'
wavelett = Function(name='wf_%s' % wn, dimensions=(time,), shape=(nt,))
wavelett = TimeFunction(name='wf_%s' % wn, dimensions=(time,), time_dim=time,
shape=(nt,), save=nt, grid=model.grid)
wavelett.data[:] = np.array(wavelet)[:, 0]
if empty_w:
source_weight = Function(name='%s_weight' % wn, grid=model.grid, space_order=0)
Expand Down
4 changes: 1 addition & 3 deletions src/pysource/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def tti_kernel(model, u1, u2, fw=True, q=None):
u1 : TimeFunction
First component (pseudo-P) of the wavefield
u2 : TimeFunction
First component (pseudo-P) of the wavefield
Second component (pseudo-S) of the wavefield
fw: Bool
Whether forward or backward in time propagation
q : TimeFunction or Expr
Expand Down Expand Up @@ -190,8 +190,6 @@ def elastic_kernel(model, v, tau, fw=True, q=None):
q : TimeFunction or Expr
Full time-space source as a tuple (one value for each component)
"""
if 'nofsdomain' in model.grid.subdomains:
raise NotImplementedError("Free surface not supported for elastic modelling")
if not fw:
raise NotImplementedError("Only forward modeling for the elastic equation")

Expand Down
9 changes: 6 additions & 3 deletions src/pysource/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ class Model(object):
"""
def __init__(self, origin, spacing, shape, space_order=8, nbl=40, dtype=np.float32,
m=None, epsilon=None, delta=None, theta=None, phi=None, rho=None,
b=None, qp=None, lam=None, mu=None, dm=None, fs=False, abox=True,
b=None, qp=None, lam=None, mu=None, dm=None, fs=False,
**kwargs):
# Setup devito grid
self.shape = tuple(shape)
self.nbl = int(nbl)
self.origin = tuple([dtype(o) for o in origin])
abc_type = "mask" if (qp is not None or mu is not None) else "damp"
self.fs = fs
self._abox = abox
self._abox = None
# Origin of the computational domain with boundary to inject/interpolate
# at the correct index
origin_pml = [dtype(o - s*nbl) for o, s in zip(origin, spacing)]
Expand Down Expand Up @@ -546,7 +546,7 @@ def spacing_map(self):
return sp_map

def abox(self, src, rec, fw=True):
if ABox is None:
if ABox is None or (src is None and rec is None):
return {}
if not fw:
src, rec = rec, src
Expand Down Expand Up @@ -638,6 +638,9 @@ def zero_thomsen(self):
def __init_abox__(self, src, rec, fw=True):
if ABox is None:
return
if src is None and rec is None:
self._abox = None
return
eps = getattr(self, 'epsilon', None)
if not fw:
src, rec = rec, src
Expand Down
2 changes: 1 addition & 1 deletion src/pysource/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def born_op(p_params, tti, visco, elas, space_order, fw, spacing, save, pt_src,
f0 = Constant('f0')

# Setting wavefield
u = wavefield(model, space_order, save=save, nt=nt, t_sub=t_sub, fw=fw)
u = wavefield(model, space_order, save=save, nt=nt, t_sub=t_sub, fw=fw, tfull=True)
ul = wavefield(model, space_order, name="l", fw=fw)

# Setup source and receiver
Expand Down

0 comments on commit 2158f99

Please sign in to comment.