Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 12, 2024
1 parent ddec77b commit 11485a4
Show file tree
Hide file tree
Showing 11 changed files with 26 additions and 39 deletions.
6 changes: 0 additions & 6 deletions deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ struct DevitoException <: Exception
msg::String
end

if PyCall.pyversion >= VersionNumber("3.12.0")
install = ["install", "--user"]
else
install = ["install"]
end

pk = try
pyimport("pkg_resources")
catch e
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/modeling_basic_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' This example is converted to a markdown file for the documentation.

#' # Import JUDI, Linear algebra utilities and Plotting
using JUDI, LinearAlgebra, SlimPlotting, SegyIO, SlimOptim
using JUDI, LinearAlgebra, SlimPlotting

#+ echo = false; results = "hidden"
close("all")
Expand Down
4 changes: 3 additions & 1 deletion src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ function _worker_pool()
return nothing
end
p = default_worker_pool()
pool = length(workers()) < 2 ? nothing : p
pool = nworkers(p) < 2 ? nothing : p
return pool
end

nworkers(::Any) = length(workers())

_TFuture = Future
_verbose = false
_devices = []
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/Modeling/distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ with different reduction functions.
function reduce!(futures::Vector{_TFuture})
isnothing(_worker_pool()) && return reduce_all_workers!(futures)
# Number of parallel workers
nwork = length(workers())
nwork = nworkers(_worker_pool())
nf = length(futures)
# Reduction batch. We want to avoid finished task to hang waiting for the
# binary tree reduction to reach their index holding memory.
Expand Down
14 changes: 0 additions & 14 deletions src/TimeModeling/Modeling/propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ the pool is empty, a standard loop and accumulation is ran. If the pool is a jul
any custom Distributed pool, the loop is distributed via `remotecall` followed by are binary tree remote reduction.
"""
function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing)
# Allocate devices
_set_devices!()
# Run distributed loop
res = Vector{_TFuture}(undef, nsrc)
for i = 1:nsrc
Expand Down Expand Up @@ -51,18 +49,6 @@ function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing)
out
end

function _set_devices!()
ndevices = length(_devices)
if ndevices < 2
return
end
asyncmap(enumerate(workers())) do (pi, p)
remotecall_wait(p) do
pyut.set_device_ids(_devices[pi % ndevices + 1])
end
end
end

_prop_fw(::judiPropagator{T, O}) where {T, O} = true
_prop_fw(::judiPropagator{T, :adjoint}) where T = false
_prop_fw(J::judiJacobian) = _prop_fw(J.F)
Expand Down
1 change: 1 addition & 0 deletions src/pysource/FD_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


trig_mapper = {cos.__sympy_class__: cos, sin.__sympy_class__: sin}

r2 = lambda x: rot_axis2(x).applyfunc(lambda i: trig_mapper.get(i.func, i.func)(*i.args))
r3 = lambda x: rot_axis3(x).applyfunc(lambda i: trig_mapper.get(i.func, i.func)(*i.args))

Expand Down
4 changes: 3 additions & 1 deletion src/pysource/geom_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from devito.tools import as_tuple

from sources import *
Expand All @@ -13,7 +15,7 @@ def src_rec(model, u, src_coords=None, rec_coords=None, wavelet=None, nt=None):
else:
src = PointSource(name="src%s" % namef, grid=model.grid, ntime=nt,
coordinates=src_coords)
src.data[:] = wavelet[:] if wavelet is not None else 0.
src.data[:] = wavelet.view(np.ndarray) if wavelet is not None else 0.
rcv = None
if rec_coords is not None:
rcv = Receiver(name="rcv%s" % namef, grid=model.grid, ntime=nt,
Expand Down
16 changes: 14 additions & 2 deletions src/pysource/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from sympy import finite_diff_weights as fd_w
from devito import (Grid, Function, SubDimension, Eq, Inc, switchconfig,
Operator, mmin, mmax, initialize_function,
Operator, mmin, mmax, initialize_function, MPI,
Abs, sqrt, sin, Constant, CustomDimension)

from devito.tools import as_tuple, memoized_func

try:
Expand Down Expand Up @@ -34,7 +35,18 @@ def _1d_cmax(self, vp, eps):
if eps is not None:
epsi = eps.data.max(axis=rdim)
vpi._local[:] *= np.sqrt(1. + 2.*epsi._local[:])
cmaxs.append(vpi)
# Gather on all ranks if distributed.
# Since we have a small-ish 1D vector we avoid the index gymnastic
# and create the full 1d vector on al ranks with the local values
# at the local indices and simply gather with Max
if vp.grid.distributor.is_parallel:
out = np.zeros(vp.grid.shape[di], dtype=vpi.dtype)
tmp = np.zeros(vp.grid.shape[di], dtype=vpi.dtype)
tmp[vp.local_indices[di]] = vpi._local
vp.grid.distributor.comm.Allreduce(tmp, out, op=MPI.MAX)
cmaxs.append(out)
else:
cmaxs.append(vpi)

return cmaxs

Expand Down
2 changes: 1 addition & 1 deletion src/pysource/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def forward_grad(model, src_coords, rcv_coords, wavelet, v,
q = extented_src(model, ws, wavelet, q=q)

# Set up PDE expression and rearrange
pde, extra = wave_kernel(model, u, q=q, f0=f0, )
pde, extra = wave_kernel(model, u, q=q, f0=f0)

# Setup source and receiver
rexpr = geom_expr(model, u, src_coords=src_coords, nt=nt,
Expand Down
2 changes: 1 addition & 1 deletion src/pysource/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def inner_grad(u, v):
v: TimeFunction
Second field
"""
return grad(u).dot(grad(v))
return grad(u, shift=.5).dot(grad(v, shift=.5))


fwi_src = lambda *ar, **kw: isic_src(*ar, icsign=-1, **kw)
Expand Down
12 changes: 1 addition & 11 deletions src/pysource/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,11 @@ def fields_kwargs(*args):
return kw


DEVICE = {"id": -1} # noqa


def set_device_ids(devid):
DEVICE["id"] = devid


def base_kwargs(dt):
"""
Most basic keyword arguments needed by the operator.
"""
if configuration['platform'].name == 'nvidiaX':
return {'dt': dt, 'deviceid': DEVICE["id"]}
else:
return {'dt': dt}
return {'dt': dt}


def cleanup_wf(u):
Expand Down

0 comments on commit 11485a4

Please sign in to comment.