Refactors and bugfixes #1736

Merged (13 commits, Oct 28, 2024)
12 changes: 5 additions & 7 deletions docs/examples/example_globcurrent.py
@@ -219,23 +219,21 @@ def test__particles_init_time():
     assert pset[0].time - pset4[0].time == 0
 
 
-@pytest.mark.xfail(reason="Time extrapolation error expected to be thrown", strict=True)
 @pytest.mark.parametrize("mode", ["scipy", "jit"])
 @pytest.mark.parametrize("use_xarray", [True, False])
 def test_globcurrent_time_extrapolation_error(mode, use_xarray):
     fieldset = set_globcurrent_fieldset(use_xarray=use_xarray)
 
     pset = parcels.ParticleSet(
         fieldset,
         pclass=ptype[mode],
         lon=[25],
         lat=[-35],
-        time=fieldset.U.time[0] - timedelta(days=1).total_seconds(),
+        time=fieldset.U.grid.time[0] - timedelta(days=1).total_seconds(),
     )
 
-    pset.execute(
-        parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
-    )
+    with pytest.raises(parcels.TimeExtrapolationError):
+        pset.execute(
+            parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
+        )
 
 
 @pytest.mark.parametrize("mode", ["scipy", "jit"])
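Dropping the xfail marker in favour of pytest.raises tightens this test: xfail passes whenever anything in the test fails, whereas pytest.raises asserts that one specific call raises one specific exception. A minimal sketch of the pattern, using a toy function of ours rather than Parcels code:

    import pytest


    def advect(seconds_past_data: float) -> None:
        # Stand-in for pset.execute() running outside the fieldset's time range.
        if seconds_past_data > 0:
            raise RuntimeError("field time extrapolation")


    def test_extrapolation_error() -> None:
        # Fails if the call raises nothing, or raises a different exception.
        with pytest.raises(RuntimeError):
            advect(86400.0)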
43 changes: 26 additions & 17 deletions parcels/field.py
@@ -1,11 +1,10 @@
 import collections
 import datetime
 import math
 import warnings
 from collections.abc import Iterable
 from ctypes import POINTER, Structure, c_float, c_int, pointer
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import dask.array as da
 import numpy as np
@@ -21,7 +20,7 @@
     assert_valid_gridindexingtype,
     assert_valid_interp_method,
 )
-from parcels.tools._helpers import deprecated_made_private
+from parcels.tools._helpers import deprecated_made_private, timedelta_to_float
 from parcels.tools.converters import (
     Geographic,
     GeographicPolar,
@@ -150,6 +149,8 @@
     * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__
     """
 
+    _cast_data_dtype: type[np.float32] | type[np.float64]
+
     def __init__(
         self,
         name: str | tuple[str, str],
@@ -162,16 +163,16 @@
         mesh: Mesh = "flat",
         timestamps=None,
         fieldtype=None,
-        transpose=False,
-        vmin=None,
-        vmax=None,
-        cast_data_dtype="float32",
-        time_origin=None,
+        transpose: bool = False,
+        vmin: float | None = None,
+        vmax: float | None = None,
+        cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
+        time_origin: TimeConverter | None = None,
         interp_method: InterpMethod = "linear",
         allow_time_extrapolation: bool | None = None,
         time_periodic: TimePeriodic = False,
         gridindexingtype: GridIndexingType = "nemo",
-        to_write=False,
+        to_write: bool = False,
         **kwargs,
     ):
         if kwargs.get("netcdf_decodewarning") is not None:
@@ -247,8 +248,8 @@
"Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object."
             )
         if self.time_periodic is not False:
-            if isinstance(self.time_periodic, datetime.timedelta):
-                self.time_periodic = self.time_periodic.total_seconds()
+            self.time_periodic = timedelta_to_float(self.time_periodic)
+
             if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic):
                 if self.grid.time[-1] - self.grid.time[0] > self.time_periodic:
                     raise ValueError("Time series provided is longer than the time_periodic parameter")
@@ -258,11 +259,19 @@

         self.vmin = vmin
         self.vmax = vmax
-        self._cast_data_dtype = cast_data_dtype
-        if self.cast_data_dtype == "float32":
-            self._cast_data_dtype = np.float32
-        elif self.cast_data_dtype == "float64":
-            self._cast_data_dtype = np.float64
+        match cast_data_dtype:
+            case "float32":
+                self._cast_data_dtype = np.float32
+            case "float64":
+                self._cast_data_dtype = np.float64
+            case _:
+                self._cast_data_dtype = cast_data_dtype
+
+        if self.cast_data_dtype not in [np.float32, np.float64]:
+            raise ValueError(
+                f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
+            )
 
         if not self.grid.defer_load:
             self.data = self._reshape(self.data, transpose)
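For readers new to structural pattern matching: the rewritten branch folds the two string spellings into numpy scalar types, passes anything else through, and funnels all bad inputs into a single validation error. The same logic as a standalone sketch (the function name normalize_cast_dtype is ours, not part of the diff):

    from typing import Literal

    import numpy as np


    def normalize_cast_dtype(
        value: type[np.float32] | type[np.float64] | Literal["float32", "float64"],
    ) -> type[np.float32] | type[np.float64]:
        # Fold the string spellings onto numpy types; pass everything else through.
        match value:
            case "float32":
                dtype = np.float32
            case "float64":
                dtype = np.float64
            case _:
                dtype = value
        # One membership check now covers both string and type inputs.
        if dtype not in (np.float32, np.float64):
            raise ValueError(f"Unsupported cast_data_dtype {value!r}. Choose either: 'float32' or 'float64'")
        return dtype


    assert normalize_cast_dtype("float32") is np.float32
    assert normalize_cast_dtype(np.float64) is np.float64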
@@ -797,7 +806,7 @@
         lat = da[dimensions["lat"]].values
 
         time_origin = TimeConverter(time[0])
-        time = time_origin.reltime(time)
+        time = time_origin.reltime(time)  # type: ignore[assignment]
 
         grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
         kwargs["time_periodic"] = time_periodic
6 changes: 3 additions & 3 deletions parcels/fieldfilebuffer.py
@@ -388,7 +388,7 @@
         self.chunk_mapping = None
 
     @classmethod
-    def add_to_dimension_name_map_global(self, name_map):
+    def add_to_dimension_name_map_global(cls, name_map):
         """
         [externally callable]
         This function adds entries to the name map from parcels_dim -> netcdf_dim. This is required if you want to
@@ -406,9 +406,9 @@
         for pcls_dim_name in name_map.keys():
             if isinstance(name_map[pcls_dim_name], list):
                 for nc_dim_name in name_map[pcls_dim_name]:
-                    self._static_name_maps[pcls_dim_name].append(nc_dim_name)
+                    cls._static_name_maps[pcls_dim_name].append(nc_dim_name)
             elif isinstance(name_map[pcls_dim_name], str):
-                self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
+                cls._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
 
     def add_to_dimension_name_map(self, name_map):
         """
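The self-to-cls renames in this file (and in fieldset.py below) are more than cosmetic: the first parameter of a @classmethod receives the class object, and calling it cls makes plain that class-level state is being mutated, with no instance involved. A minimal sketch:

    class Registry:
        _entries: list[str] = []  # class-level, shared state

        @classmethod
        def register(cls, name: str) -> None:
            # `cls` is the class itself (Registry or a subclass), not an instance;
            # the old spelling `self` worked but suggested instance state.
            cls._entries.append(name)


    Registry.register("lon")  # callable with no instance
    assert Registry._entries == ["lon"]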
4 changes: 2 additions & 2 deletions parcels/fieldset.py
@@ -344,8 +344,8 @@

     @classmethod
     @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
-    def parse_wildcards(self, *args, **kwargs):
-        return self._parse_wildcards(*args, **kwargs)
+    def parse_wildcards(cls, *args, **kwargs):
+        return cls._parse_wildcards(*args, **kwargs)
 
     @classmethod
     def _parse_wildcards(cls, paths, filenames, var):
9 changes: 5 additions & 4 deletions parcels/kernel.py
@@ -77,7 +77,6 @@ def __init__(
         self.funccode = funccode
         self.py_ast = py_ast
         self.dyn_srcs = []
-        self.static_srcs = []
         self.src_file = None
         self.lib_file = None
         self.log_file = None
@@ -562,9 +561,11 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
     def cleanup_remove_files(lib_file, all_files_array, delete_cfiles):
         if lib_file is not None:
             if os.path.isfile(lib_file):  # and delete_cfiles
-                [os.remove(s) for s in [lib_file] if os.path is not None and os.path.exists(s)]
-        if delete_cfiles and len(all_files_array) > 0:
-            [os.remove(s) for s in all_files_array if os.path is not None and os.path.exists(s)]
+                os.remove(lib_file)
+        if delete_cfiles:
+            for s in all_files_array:
+                if os.path.exists(s):
+                    os.remove(s)
Comment on lines -565 to +568

Member: Why don't we need to check for os.path is not None here anymore? Was that old, redundant code?

Contributor (author): os.path is a module of the os package, so it is never None. A while ago we had

    from os import path
    ...
    [... if path is not None ...]

so it wasn't immediately evident before.
 
     @staticmethod
     def cleanup_unload_lib(lib):
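Two things about the cleanup rewrite, expanding on the thread above: os.path is a submodule that importing os always provides, so the old guard could never be False; and a list comprehension run purely for its side effects builds and throws away a list of None values. A small sketch with made-up file names:

    import os

    # os.path is a module object shipped with `os`; it cannot be None.
    assert os.path is not None

    files = ["kernel_abc.c", "kernel_abc.so"]  # hypothetical leftovers

    # Before: a comprehension used for its side effects only.
    [os.remove(f) for f in files if os.path.exists(f)]

    # After: a plain loop states the intent directly.
    for f in files:
        if os.path.exists(f):
            os.remove(f)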
6 changes: 3 additions & 3 deletions parcels/particle.py
@@ -201,13 +201,13 @@

     def __repr__(self):
         time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
-        str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
+        p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
         for var in vars(type(self)):
             if var in ["lon_nextloop", "lat_nextloop", "depth_nextloop", "time_nextloop"]:
                 continue
             if type(getattr(type(self), var)) is Variable and getattr(type(self), var).to_write is True:
-                str += f"{var}={getattr(self, var):f}, "
-        return str + f"time={time_string})"
+                p_string += f"{var}={getattr(self, var):f}, "
+        return p_string + f"time={time_string})"
 
     @classmethod
     def add_variable(cls, var, *args, **kwargs):
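The str-to-p_string rename in __repr__ (here and in particledata.py below) removes shadowing of a built-in: once a local is named str, the type itself becomes uncallable for the rest of the scope. A minimal sketch:

    def bad_repr(lon: float) -> str:
        str = f"P(lon={lon:f}"  # shadows the builtin `str` for this scope
        # Calling str(...) below this line would raise:
        # TypeError: 'str' object is not callable
        return str + ")"


    def good_repr(lon: float) -> str:
        p_string = f"P(lon={lon:f}"  # descriptive name; builtin stays usable
        return p_string + ")"


    assert good_repr(1.0) == "P(lon=1.000000)"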
6 changes: 3 additions & 3 deletions parcels/particledata.py
@@ -460,7 +460,7 @@ def getPType(self):

     def __repr__(self):
         time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
-        str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
+        p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
         for var in self._pcoll.ptype.variables:
             if var.name in [
                 "lon_nextloop",
@@ -470,8 +470,8 @@ def __repr__(self):
             ]:  # TODO check if time_nextloop is needed (or can work with time-dt?)
                 continue
             if var.to_write is not False and var.name not in ["id", "lon", "lat", "depth", "time"]:
-                str += f"{var.name}={getattr(self, var.name):f}, "
-        return str + f"time={time_string})"
+                p_string += f"{var.name}={getattr(self, var.name):f}, "
+        return p_string + f"time={time_string})"
 
     def delete(self):
         """Signal the particle for deletion."""
18 changes: 9 additions & 9 deletions parcels/particlefile.py
@@ -10,7 +10,7 @@

 import parcels
 from parcels._compat import MPI
-from parcels.tools._helpers import deprecated, deprecated_made_private
+from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float
 from parcels.tools.warnings import FileWarning
 
 __all__ = ["ParticleFile"]
@@ -48,7 +48,7 @@
"""

     def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
-        self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
+        self._outputdt = timedelta_to_float(outputdt)
         self._chunks = chunks
         self._particleset = particleset
         self._parcels_mesh = "spherical"
@@ -253,7 +253,7 @@
             Z.append(a, axis=axis)
         zarr.consolidate_metadata(store)
 
-    def write(self, pset, time, indices=None):
+    def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None):
         """Write all data from one time step to the zarr file,
         before the particle locations are updated.
 
@@ -264,7 +264,7 @@
         time :
             Time at which to write ParticleSet
         """
-        time = time.total_seconds() if isinstance(time, timedelta) else time
+        time = timedelta_to_float(time) if time is not None else None
 
         if pset.particledata._ncount == 0:
             warnings.warn(
@@ -295,18 +295,18 @@
         if self.create_new_zarrfile:
             if self.chunks is None:
                 self._chunks = (len(ids), 1)
-            if pset._repeatpclass is not None and self.chunks[0] < 1e4:
+            if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
                 warnings.warn(
                     f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
                     f"a significant slowdown in Parcels when many calls to repeatdt. "
                     f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
                     FileWarning,
                     stacklevel=2,
                 )
-            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
-                arrsize = (self._maxids, self.chunks[1])
+            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
+                arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
             else:
-                arrsize = (len(ids), self.chunks[1])
+                arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
             ds = xr.Dataset(
                 attrs=self.metadata,
                 coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
@@ -331,7 +331,7 @@
                 data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
                 dims = ["trajectory", "obs"]
             ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
-            ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
+            ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
             ds.to_zarr(self.fname, mode="w")
             self._create_new_zarrfile = False
         else:
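The scattered # type: ignore[index] comments deserve a note: chunks here appears to be a property declared as returning a tuple or None, backed by self._chunks. Assigning the backing attribute does not narrow the property's declared type, so the checker still flags the indexing even though the None case was just handled. A reduced sketch of that pattern; the Writer class is invented:

    class Writer:
        def __init__(self, chunks: tuple[int, int] | None = None) -> None:
            self._chunks = chunks  # None means "decide at first write"

        @property
        def chunks(self) -> tuple[int, int] | None:
            return self._chunks

        def first_write(self, n_ids: int) -> int:
            if self.chunks is None:
                self._chunks = (n_ids, 1)
            # Safe at runtime, but the Optional property is not narrowed by
            # the assignment to its backing field, hence the targeted ignore:
            return self.chunks[0]  # type: ignore[index]


    assert Writer().first_write(10) == 10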
37 changes: 20 additions & 17 deletions parcels/particleset.py
@@ -27,7 +27,7 @@
 from parcels.particle import JITParticle, Variable
 from parcels.particledata import ParticleData, ParticleDataIterator
 from parcels.particlefile import ParticleFile
-from parcels.tools._helpers import deprecated, deprecated_made_private
+from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float
 from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
 from parcels.tools.global_statics import get_package_dir
 from parcels.tools.loggers import logger
@@ -188,12 +188,13 @@
                 lon.size == kwargs[kwvar].size
             ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
 
-        self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt
+        self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None
+
         if self.repeatdt:
             if self.repeatdt <= 0:
-                raise "Repeatdt should be > 0"
+                raise ValueError("Repeatdt should be > 0")
             if time[0] and not np.allclose(time, time[0]):
-                raise "All Particle.time should be the same when repeatdt is not None"
+                raise ValueError("All Particle.time should be the same when repeatdt is not None")
             self._repeatpclass = pclass
             self._repeatkwargs = kwargs
             self._repeatkwargs.pop("partition_function", None)
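The two raise statements replaced above were Python 2 leftovers: Python 3 only raises objects deriving from BaseException, so raising a bare string fails with a TypeError and the intended message never reaches the user. A quick demonstration:

    try:
        raise "Repeatdt should be > 0"  # not an exception object
    except TypeError as err:
        print(err)  # exceptions must derive from BaseException

    # The fix wraps the message in a proper exception type:
    try:
        raise ValueError("Repeatdt should be > 0")
    except ValueError as err:
        print(err)  # Repeatdt should be > 0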
@@ -981,13 +982,13 @@
         pyfunc=AdvectionRK4,
         pyfunc_inter=None,
         endtime=None,
-        runtime=None,
-        dt=1.0,
+        runtime: float | timedelta | np.timedelta64 | None = None,
+        dt: float | timedelta | np.timedelta64 = 1.0,
         output_file=None,
         verbose_progress=True,
         postIterationCallbacks=None,
-        callbackdt=None,
-        delete_cfiles=True,
+        callbackdt: float | timedelta | np.timedelta64 | None = None,
+        delete_cfiles: bool = True,
     ):
         """Execute a given kernel function over the particle set for multiple timesteps.
 
@@ -1067,22 +1068,24 @@
             if self.time_origin.calendar is None:
                 raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double")
             endtime = self.time_origin.reltime(endtime)
-        if isinstance(runtime, timedelta):
-            runtime = runtime.total_seconds()
-        if isinstance(dt, timedelta):
-            dt = dt.total_seconds()
+
+        if runtime is not None:
+            runtime = timedelta_to_float(runtime)
+
+        dt = timedelta_to_float(dt)
+
         if abs(dt) <= 1e-6:
             raise ValueError("Time step dt is too small")
         if (dt * 1e6) % 1 != 0:
             raise ValueError("Output interval should not have finer precision than 1e-6 s")
-        outputdt = output_file.outputdt if output_file else np.inf
-        if isinstance(outputdt, timedelta):
-            outputdt = outputdt.total_seconds()
+
+        outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf
+
         if outputdt is not None:
             _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"])
 
-        if isinstance(callbackdt, timedelta):
-            callbackdt = callbackdt.total_seconds()
+        if callbackdt is not None:
+            callbackdt = timedelta_to_float(callbackdt)
 
         assert runtime is None or runtime >= 0, "runtime must be positive"
         assert outputdt is None or outputdt >= 0, "outputdt must be positive"
12 changes: 12 additions & 0 deletions parcels/tools/_helpers.py
@@ -3,6 +3,9 @@
 import functools
 import warnings
 from collections.abc import Callable
+from datetime import timedelta
+
+import numpy as np
 
 PACKAGE = "Parcels"
 
@@ -56,3 +59,12 @@ def deprecated_made_private(func: Callable) -> Callable:

 def patch_docstring(obj: Callable, extra: str) -> None:
     obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip()
+
+
+def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
+    """Convert a timedelta to a float in seconds."""
+    if isinstance(dt, timedelta):
+        return dt.total_seconds()
+    if isinstance(dt, np.timedelta64):
+        return float(dt / np.timedelta64(1, "s"))
+    return float(dt)
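This helper centralizes conversions that were previously repeated inline in Field, ParticleFile, and ParticleSet, and it accepts all three time spellings used across the codebase. A usage sketch (assuming the private module path stays importable):

    from datetime import timedelta

    import numpy as np

    from parcels.tools._helpers import timedelta_to_float

    assert timedelta_to_float(300.0) == 300.0
    assert timedelta_to_float(timedelta(minutes=5)) == 300.0
    assert timedelta_to_float(np.timedelta64(5, "m")) == 300.0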