Skip to content

Commit

Permalink
✨ RegularGridInterpolator does not handle interpolator keywords.
Browse files Browse the repository at this point in the history
  • Loading branch information
fbriol committed Nov 6, 2023
1 parent 39905ba commit 72ce058
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/pyinterp/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ def __call__(self,
method: str = 'bilinear',
bounds_error: bool = False,
bicubic_kwargs: Optional[Dict] = None,
num_threads: int = 0) -> numpy.ndarray:
num_threads: int = 0,
**kwargs) -> numpy.ndarray:
"""Interpolation at coordinates.
Args:
Expand All @@ -520,6 +521,11 @@ def __call__(self,
num_threads: The number of threads to use for the computation. If 0
all CPUs are used. If 1 is given, no parallel computing code is
used at all, which is useful for debugging. Defaults to ``0``.
**kwargs: List of keyword arguments provided to the interpolation
method :py:meth:`pyinterp.bivariate <pyinterp.bivariate>`,
:py:meth:`pyinterp.trivariate <pyinterp.trivariate>` or
:py:meth:`pyinterp.quadrivariate <pyinterp.quadrivariate>`
depending on the number of dimensions of the grid.
Returns:
New array on the new coordinates.
"""
Expand All @@ -532,4 +538,5 @@ def __call__(self,
return self._interp(coords,
interpolator=method,
bounds_error=bounds_error,
num_threads=num_threads)
num_threads=num_threads,
**kwargs)
9 changes: 9 additions & 0 deletions src/pyinterp/tests/test_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ def test_biavariate(pytestconfig):
z = grid(collections.OrderedDict(lon=x.ravel(), lat=y.ravel()),
method='bilinear')
assert isinstance(z, np.ndarray)
z = grid(collections.OrderedDict(lon=x.ravel(), lat=y.ravel()),
method='inverse_distance_weighting',
p=1)
assert isinstance(z, np.ndarray)

with pytest.raises(TypeError):
z = grid(collections.OrderedDict(lon=x.ravel(), lat=y.ravel()),
method='nearest',
p=1)

# This is necessary in order for Dask to scatter the callable instances.
other = pickle.loads(pickle.dumps(grid, protocol=0))
Expand Down

0 comments on commit 72ce058

Please sign in to comment.