Skip to content

Commit

Permalink
🚨 Update type hints and import statements
Browse files Browse the repository at this point in the history
  • Loading branch information
fbriol committed Feb 26, 2024
1 parent 5792f00 commit 6a5d488
Show file tree
Hide file tree
Showing 24 changed files with 920 additions and 752 deletions.
12 changes: 6 additions & 6 deletions examples/ex_geodetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@
# %%
# It is possible to do the same calculation on a large number of coordinates
# quickly.
lon = numpy.arange(0, 360, 10)
lat = numpy.arange(-90, 90.5, 10)
lon = numpy.arange(0, 360, 10, dtype=numpy.float64)
lat = numpy.arange(-90, 90.5, 10, dtype=numpy.float64)
mx, my = numpy.meshgrid(lon, lat)
distances = pyinterp.geodetic.coordinate_distances(mx.ravel(),
my.ravel(),
mx.ravel() + 1,
my.ravel() + 1,
mx.ravel() + 1.0,
my.ravel() + 1.0,
strategy='vincenty',
wgs=wgs84,
num_threads=1)
Expand Down Expand Up @@ -402,8 +402,8 @@
[-36, -54.9238], [-36.25, -54.9238]]

# %%
lon = numpy.arange(0, 360, 10)
lat = numpy.arange(-80, 90, 10)
lon = numpy.arange(0, 360, 10, dtype=numpy.float64)
lat = numpy.arange(-80, 90, 10, dtype=numpy.float64)
mx, my = numpy.meshgrid(lon, lat)

# %%
Expand Down
20 changes: 12 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# BSD-style license that can be found in the LICENSE file.
"""This script is the entry point for building, distributing and installing
this module using distutils/setuptools."""
from typing import List, Optional, Tuple
from __future__ import annotations

from typing import Any
import datetime
import os
import pathlib
Expand Down Expand Up @@ -32,7 +34,7 @@
OSX_DEPLOYMENT_TARGET = '10.14'


def compare_setuptools_version(required: Tuple[int, ...]) -> bool:
def compare_setuptools_version(required: tuple[int, ...]) -> bool:
"""Compare the version of setuptools with the required version."""
current = tuple(map(int, setuptools.__version__.split('.')[:2]))
return current >= required
Expand Down Expand Up @@ -108,7 +110,8 @@ def revision() -> str:
return match.group(1)
raise AssertionError()

stdout = execute('git describe --tags --dirty --long --always').strip()
stdout: Any = execute(
'git describe --tags --dirty --long --always').strip()
pattern = re.compile(r'([\w\d\.]+)-(\d+)-g([\w\d]+)(?:-(dirty))?')
match = pattern.search(stdout)
if match is None:
Expand Down Expand Up @@ -248,13 +251,13 @@ def run(self) -> None:
self.build_cmake(ext)
super().run()

def boost(self) -> Optional[List[str]]:
def boost(self) -> list[str] | None:
"""Get the default boost path in Anaconda's environment."""
# Do not search system for Boost & disable the search for boost-cmake
boost_option = '-DBoost_NO_SYSTEM_PATHS=TRUE ' \
'-DBoost_NO_BOOST_CMAKE=TRUE'
boost_root = sys.prefix
if pathlib.Path(boost_root, 'include', 'boost').exists():
boost_root = pathlib.Path(sys.prefix)
if (boost_root / 'include' / 'boost').exists():
return f'{boost_option} -DBoost_ROOT={boost_root}'.split()
boost_root = pathlib.Path(sys.prefix, 'Library', 'include')
if not boost_root.exists():
Expand All @@ -265,7 +268,7 @@ def boost(self) -> Optional[List[str]]:
return None
return f'{boost_option} -DBoost_INCLUDE_DIR={boost_root}'.split()

def eigen(self) -> Optional[str]:
def eigen(self) -> str | None:
"""Get the default Eigen3 path in Anaconda's environment."""
eigen_include_dir = pathlib.Path(sys.prefix, 'include', 'eigen3')
if eigen_include_dir.exists():
Expand Down Expand Up @@ -309,8 +312,9 @@ def is_conda() -> bool:
result = True
return result

def set_cmake_user_options(self) -> List[str]:
def set_cmake_user_options(self) -> list[str]:
"""Sets the options defined by the user."""
cmake_variable: Any
is_conda = self.is_conda()
result = []

Expand Down
10 changes: 4 additions & 6 deletions src/pyinterp/_geohash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
Geohash encoding and decoding
-----------------------------
"""
from typing import Optional, Tuple, Type
from __future__ import annotations

#
import numpy
import xarray

#
from . import geodetic
from .core import GeoHash as BaseGeoHash, geohash

Expand Down Expand Up @@ -44,7 +42,7 @@ class GeoHash(BaseGeoHash):

@classmethod
def grid(cls,
box: Optional[geodetic.Box] = None,
box: geodetic.Box | None = None,
precision: int = 1) -> xarray.Dataset:
"""Return the GeoHash grid covering the provided box.
Expand Down Expand Up @@ -85,7 +83,7 @@ def grid(cls,
})

@staticmethod
def from_string(code: str, round: bool = False) -> 'GeoHash':
def from_string(code: str, round: bool = False) -> GeoHash:
"""Create from its string representation.
Args:
Expand All @@ -105,5 +103,5 @@ def __repr__(self) -> str:
lon, lat, precision = super().reduce()
return f'{self.__class__.__name__}({lon}, {lat}, {precision})'

def __reduce__(self) -> Tuple[Type, Tuple[float, float, int]]:
def __reduce__(self) -> tuple[type, tuple[float, float, int]]:
return (self.__class__, super().reduce())
38 changes: 20 additions & 18 deletions src/pyinterp/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
Build interpolation objects from xarray.DataArray instances
"""
from typing import Dict, Hashable, Optional, Tuple, Union
from __future__ import annotations

from typing import Hashable
import pickle

import numpy
Expand All @@ -29,7 +31,7 @@ class AxisIdentifier:
def __init__(self, data_array: xr.DataArray):
self.data_array = data_array

def _axis(self, units: cf.AxisUnit) -> Optional[str]:
def _axis(self, units: cf.AxisUnit) -> str | None:
"""Returns the name of the dimension that defines an axis.
Args:
Expand All @@ -40,18 +42,18 @@ def _axis(self, units: cf.AxisUnit) -> Optional[str]:
"""
for name, coord in self.data_array.coords.items():
if hasattr(coord, 'units') and coord.units in units:
return name
return name # type: ignore
return None

def longitude(self) -> Optional[str]:
def longitude(self) -> str | None:
"""Returns the name of the dimension that defines a longitude axis.
Returns:
The name of the longitude coordinate
"""
return self._axis(cf.AxisLongitudeUnit())

def latitude(self) -> Optional[str]:
def latitude(self) -> str | None:
"""Returns the name of the dimension that defines a latitude axis.
Returns:
Expand All @@ -62,7 +64,7 @@ def latitude(self) -> Optional[str]:

def _dims_from_data_array(data_array: xr.DataArray,
geodetic: bool,
ndims: Optional[int] = 2) -> Tuple[str, str]:
ndims: int | None = 2) -> tuple[str, str]:
"""Gets the name of the dimensions that define the grid axes. the
longitudes and latitudes of the data array.
Expand All @@ -88,7 +90,7 @@ def _dims_from_data_array(data_array: xr.DataArray,
f'{ndims}, found {size}.')

if not geodetic:
return tuple(data_array.coords)[:2]
return tuple(data_array.coords)[:2] # type: ignore

ident = AxisIdentifier(data_array)
lon = ident.longitude()
Expand All @@ -102,9 +104,9 @@ def _dims_from_data_array(data_array: xr.DataArray,

def _coords(
coords: dict,
dims: Tuple,
datetime64: Optional[Tuple[Hashable, core.TemporalAxis]] = None,
) -> Tuple:
dims: tuple,
datetime64: tuple[Hashable, core.TemporalAxis] | None = None,
) -> tuple:
"""Get the list of arguments to provide to the grid interpolation
functions.
Expand Down Expand Up @@ -254,6 +256,7 @@ def __init__(self,
self._dims = (x, y, z)
# Should the grid manage a time axis?
dtype = data_array.coords[z].dtype
self._datetime64: tuple[Hashable, core.TemporalAxis] | None
if 'datetime64' in dtype.name or 'timedelta64' in dtype.name:
self._datetime64 = z, core.TemporalAxis(
data_array.coords[z].values)
Expand Down Expand Up @@ -449,9 +452,8 @@ def __init__(self,
increasing_axes: bool = True,
geodetic: bool = True):
if len(array.shape) == 2:
self._grid = Grid2D(array,
increasing_axes=increasing_axes,
geodetic=geodetic)
self._grid: (Grid2D | Grid3D | Grid4D) = Grid2D(
array, increasing_axes=increasing_axes, geodetic=geodetic)
self._interp = self._grid.bivariate
elif len(array.shape) == 3:
self._grid = Grid3D(array,
Expand All @@ -467,13 +469,13 @@ def __init__(self,
raise NotImplementedError(
'Only the 2D, 3D or 4D grids can be interpolated.')

def __getstate__(self) -> Tuple[bytes]:
def __getstate__(self) -> tuple[bytes]:
# Walk around a bug with pybind11 and pickle starting with Python 3.9
# Serialize the object here with highest protocol.
return (pickle.dumps((self._grid, self._interp),
protocol=pickle.HIGHEST_PROTOCOL), )

def __setstate__(self, state: Tuple[bytes]) -> None:
def __setstate__(self, state: tuple[bytes]) -> None:
# Walk around a bug with pybind11 and pickle starting with Python 3.9
# Deserialize the object here with highest protocol.
self._grid, self._interp = pickle.loads(state[0])
Expand All @@ -488,7 +490,7 @@ def ndim(self) -> int:
return self._grid.array.ndim

@property
def grid(self) -> Union[Grid2D, Grid3D, Grid4D]:
def grid(self) -> Grid2D | Grid3D | Grid4D:
"""Gets the instance of handling the regular grid for interpolations.
Returns:
Expand All @@ -497,10 +499,10 @@ def grid(self) -> Union[Grid2D, Grid3D, Grid4D]:
return self._grid

def __call__(self,
coords: Dict,
coords: dict,
method: str = 'bilinear',
bounds_error: bool = False,
bicubic_kwargs: Optional[Dict] = None,
bicubic_kwargs: dict | None = None,
num_threads: int = 0,
**kwargs) -> numpy.ndarray:
"""Interpolation at coordinates.
Expand Down
39 changes: 23 additions & 16 deletions src/pyinterp/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
Data binning
------------
"""
from typing import Optional, Tuple, Union
from __future__ import annotations

from typing import Union
import copy

import dask.array.core
import numpy

from . import core, geodetic

#: The supported data types for the binning 2D
Binning2DTyped = Union[core.Binning2DFloat64, core.Binning2DFloat32]


class Binning2D:
"""Group a number of more or less continuous values into a smaller number
Expand Down Expand Up @@ -46,10 +51,10 @@ class Binning2D:
def __init__(self,
x: core.Axis,
y: core.Axis,
wgs: Optional[geodetic.Spheroid] = None,
wgs: geodetic.Spheroid | None = None,
dtype: numpy.dtype = numpy.dtype('float64')):
if dtype == numpy.dtype('float64'):
self._instance = core.Binning2DFloat64(x, y, wgs)
self._instance: Binning2DTyped = core.Binning2DFloat64(x, y, wgs)
elif dtype == numpy.dtype('float32'):
self._instance = core.Binning2DFloat32(x, y, wgs)
else:
Expand All @@ -67,7 +72,7 @@ def y(self) -> core.Axis:
return self._instance.y

@property
def wgs(self) -> Optional[core.geodetic.Spheroid]:
def wgs(self) -> core.geodetic.Spheroid | None:
"""Gets the geodetic system handled of the grid."""
return self._instance.wgs

Expand All @@ -84,7 +89,7 @@ def __repr__(self) -> str:
result.append(f' y: {self._instance.y}')
return '\n'.join(result)

def __add__(self, other: 'Binning2D') -> 'Binning2D':
def __add__(self, other: Binning2D) -> Binning2D:
"""Overrides the default behavior of the ``+`` operator."""
result = copy.copy(self)
if type(result._instance) != type(other._instance): # noqa: E721
Expand Down Expand Up @@ -146,9 +151,9 @@ def push(self,
self._instance.push(x, y, z, simple)

def push_delayed(self,
x: Union[numpy.ndarray, dask.array.core.Array],
y: Union[numpy.ndarray, dask.array.core.Array],
z: Union[numpy.ndarray, dask.array.core.Array],
x: numpy.ndarray | dask.array.core.Array,
y: numpy.ndarray | dask.array.core.Array,
z: numpy.ndarray | dask.array.core.Array,
simple: bool = True) -> dask.array.core.Array:
"""Push new samples into the defined bins from dask array.
Expand Down Expand Up @@ -239,10 +244,12 @@ class Binning1D:

def __init__(self,
x: core.Axis,
range: Optional[Tuple[float, float]] = None,
range: tuple[float, float] | None = None,
dtype: numpy.dtype = numpy.dtype('float64')):
if dtype == numpy.dtype('float64'):
self._instance = core.Binning1DFloat64(x, range)
self._instance: (core.Binning1DFloat64
| core.Binning1DFloat32) = core.Binning1DFloat64(
x, range)
elif dtype == numpy.dtype('float32'):
self._instance = core.Binning1DFloat32(x, range)
else:
Expand All @@ -254,7 +261,7 @@ def x(self) -> core.Axis:
"""Gets the bin centers for the X Axis of the grid."""
return self._instance.x

def range(self) -> Tuple[float, float]:
def range(self) -> tuple[float, float]:
"""Gets the lower and upper range of the bins."""
return self._instance.range()

Expand All @@ -272,7 +279,7 @@ def __repr__(self) -> str:
result.append(f' {self._instance.range()}')
return '\n'.join(result)

def __add__(self, other: 'Binning1D') -> 'Binning1D':
def __add__(self, other: Binning1D) -> Binning1D:
"""Overrides the default behavior of the ``+`` operator."""
result = copy.copy(self)
if type(result._instance) != type(other._instance): # noqa: E721
Expand All @@ -284,7 +291,7 @@ def push(
self,
x: numpy.ndarray,
z: numpy.ndarray,
weights: Optional[numpy.ndarray] = None,
weights: numpy.ndarray | None = None,
) -> None:
"""Push new samples into the defined bins.
Expand All @@ -301,9 +308,9 @@ def push(

def push_delayed(
self,
x: Union[numpy.ndarray, dask.array.core.Array],
z: Union[numpy.ndarray, dask.array.core.Array],
weights: Optional[Union[numpy.ndarray, dask.array.core.Array]] = None,
x: numpy.ndarray | dask.array.core.Array,
z: numpy.ndarray | dask.array.core.Array,
weights: numpy.ndarray | dask.array.core.Array | None = None,
) -> dask.array.core.Array:
"""Push new samples into the defined bins from dask array.
Expand Down
Loading

0 comments on commit 6a5d488

Please sign in to comment.