Skip to content

Commit

Permalink
Add type ignore comments
Browse files Browse the repository at this point in the history
Postpone typing to a future refactoring (particularly for reltime)
  • Loading branch information
VeckoTheGecko committed Oct 28, 2024
1 parent aef832a commit b351559
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def from_xarray(
lat = da[dimensions["lat"]].values

time_origin = TimeConverter(time[0])
time = time_origin.reltime(time)
time = time_origin.reltime(time) # type: ignore[assignment]

grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
kwargs["time_periodic"] = time_periodic
Expand Down
10 changes: 5 additions & 5 deletions parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,18 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
if self.create_new_zarrfile:
if self.chunks is None:
self._chunks = (len(ids), 1)
if pset._repeatpclass is not None and self.chunks[0] < 1e4:
if pset._repeatpclass is not None and self.chunks[0] < 1e4: # type: ignore[index]
warnings.warn(
f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
f"a significant slowdown in Parcels when many calls to repeatdt. "
f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
FileWarning,
stacklevel=2,
)
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
arrsize = (self._maxids, self.chunks[1])
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
else:
arrsize = (len(ids), self.chunks[1])
arrsize = (len(ids), self.chunks[1]) # type: ignore[index]
ds = xr.Dataset(
attrs=self.metadata,
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
Expand All @@ -331,7 +331,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
dims = ["trajectory", "obs"]
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index]
ds.to_zarr(self.fname, mode="w")
self._create_new_zarrfile = False
else:
Expand Down
10 changes: 5 additions & 5 deletions parcels/tools/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime.
elif isinstance(time_origin, cftime.datetime):
self.calendar = time_origin.calendar

def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float:
def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float | npt.NDArray:
"""Method to compute the difference, in seconds, between a time and the time_origin
of the TimeConverter
Expand All @@ -80,26 +80,26 @@ def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.
"""
time = time.time_origin if isinstance(time, TimeConverter) else time
if self.calendar in ["np_datetime64", "np_timedelta64"]:
return (time - self.time_origin) / np.timedelta64(1, "s")
return (time - self.time_origin) / np.timedelta64(1, "s") # type: ignore
elif self.calendar in _get_cftime_calendars():
if isinstance(time, (list, np.ndarray)):
try:
return np.array([(t - self.time_origin).total_seconds() for t in time])
return np.array([(t - self.time_origin).total_seconds() for t in time]) # type: ignore
except ValueError:
raise ValueError(
f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n"
f"Provide 'time' as a {type(self.time_origin)} object?"
)
else:
try:
return (time - self.time_origin).total_seconds()
return (time - self.time_origin).total_seconds() # type: ignore
except ValueError:
raise ValueError(
f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n"
f"Provide 'time' as a {type(self.time_origin)} object?"
)
elif self.calendar is None:
return time - self.time_origin
return time - self.time_origin # type: ignore
else:
raise RuntimeError(f"Calendar {self.calendar} not implemented in TimeConverter")

Expand Down

0 comments on commit b351559

Please sign in to comment.