diff --git a/examples/shallow_water.py b/examples/shallow_water.py index e1bc10d..bff56f2 100644 --- a/examples/shallow_water.py +++ b/examples/shallow_water.py @@ -209,14 +209,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly): bath = 1.0 return bath * create_full(T_shape, 1.0, dtype) - # inital elevation - u0, v0, e0 = exact_solution( - 0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d - ) - e[:, :] = e0.to_device(device) - u[:, :] = u0.to_device(device) - v[:, :] = v0.to_device(device) - # set bathymetry # h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device) # steady state potential energy @@ -335,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt) e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt) + # warm jit cache + step(u, v, e, u1, v1, e1, u2, v2, e2) + sync() + + # initial solution + u0, v0, e0 = exact_solution( + 0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d + ) + e[:, :] = e0.to_device(device) + u[:, :] = u0.to_device(device) + v[:, :] = v0.to_device(device) + t = 0 i_export = 0 next_t_export = 0