Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: use sphinx_autodoc_typehints #386

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions adaptive/learner/learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import copy, deepcopy
from numbers import Integral as Int
from numbers import Real
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Union

import cloudpickle
import numpy as np
Expand All @@ -24,12 +24,22 @@
partial_function_from_dataframe,
)

if TYPE_CHECKING:
import holoviews

try:
from typing import TypeAlias
except ImportError:
# Remove this when we drop support for Python 3.9
from typing_extensions import TypeAlias

try:
from typing import Literal
except ImportError:
# Remove this when we drop support for Python 3.7
from typing_extensions import Literal


Comment on lines +36 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can drop 3.7, see https://numpy.org/neps/nep-0029-deprecation_policy.html

We didn't discuss this, but NEP 29 is a good source.

try:
import pandas

Expand Down Expand Up @@ -145,7 +155,7 @@ def resolution_loss_function(

Returns
-------
loss_function : callable
loss_function

Examples
--------
Expand Down Expand Up @@ -230,12 +240,12 @@ class Learner1D(BaseLearner):

Parameters
----------
function : callable
function
The function to learn. Must take a single real parameter and
return a real number or 1D array.
bounds : pair of reals
bounds
The bounds of the interval on which to learn 'function'.
loss_per_interval: callable, optional
loss_per_interval
A function that returns the loss for a single interval of the domain.
If not provided, then a default is used, which uses the scaled distance
in the x-y plane as the loss. See the notes for more details.
Expand Down Expand Up @@ -356,15 +366,15 @@ def to_dataframe(

Parameters
----------
with_default_function_args : bool, optional
with_default_function_args
Include the ``learner.function``'s default arguments as a
column, by default True
function_prefix : str, optional
function_prefix
Prefix to the ``learner.function``'s default arguments' names,
by default "function."
x_name : str, optional
x_name
Name of the input value, by default "x"
y_name : str, optional
y_name
Name of the output value, by default "y"

Returns
Expand Down Expand Up @@ -403,16 +413,16 @@ def load_dataframe(

Parameters
----------
df : pandas.DataFrame
df
The data to load.
with_default_function_args : bool, optional
with_default_function_args
The ``with_default_function_args`` used in ``to_dataframe()``,
by default True
function_prefix : str, optional
function_prefix
The ``function_prefix`` used in ``to_dataframe``, by default "function."
x_name : str, optional
x_name
The ``x_name`` used in ``to_dataframe``, by default "x"
y_name : str, optional
y_name
The ``y_name`` used in ``to_dataframe``, by default "y"
"""
self.tell_many(df[x_name].values, df[y_name].values)
Expand Down Expand Up @@ -795,17 +805,19 @@ def _loss(
loss = mapping[ival]
return finite_loss(ival, loss, self._scale[0])

def plot(self, *, scatter_or_line: str = "scatter"):
def plot(
self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"
) -> holoviews.Overlay:
"""Returns a plot of the evaluated data.

Parameters
----------
scatter_or_line : str, default: "scatter"
scatter_or_line
Plot as a scatter plot ("scatter") or a line plot ("line").

Returns
-------
plot : `holoviews.Overlay`
plot
Plot of the evaluated data.
"""
if scatter_or_line not in ("scatter", "line"):
Expand Down
85 changes: 45 additions & 40 deletions adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from copy import copy
from math import sqrt
from typing import Callable, Iterable
from typing import TYPE_CHECKING, Callable, Iterable

import cloudpickle
import numpy as np
Expand All @@ -22,6 +22,9 @@
partial_function_from_dataframe,
)

if TYPE_CHECKING:
import holoviews

try:
import pandas

Expand All @@ -40,11 +43,11 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
deviations : list
deviations
The deviation per triangle.
"""
values = ip.values / (ip.values.ptp(axis=0).max() or 1)
Expand Down Expand Up @@ -79,11 +82,11 @@ def areas(ip: LinearNDInterpolator) -> np.ndarray:

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
areas : numpy.ndarray
areas
The area per triangle in ``ip.tri``.
"""
p = ip.tri.points[ip.tri.simplices]
Expand All @@ -99,11 +102,11 @@ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
losses : numpy.ndarray
losses
Loss per triangle in ``ip.tri``.

Examples
Expand Down Expand Up @@ -136,7 +139,7 @@ def resolution_loss_function(

Returns
-------
loss_function : callable
loss_function

Examples
--------
Expand Down Expand Up @@ -173,11 +176,11 @@ def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
losses : numpy.ndarray
losses
Loss per triangle in ``ip.tri``.

Examples
Expand Down Expand Up @@ -217,11 +220,11 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
losses : numpy.ndarray
losses
Loss per triangle in ``ip.tri``.
"""
dev = np.sum(deviations(ip), axis=0)
Expand All @@ -241,15 +244,15 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr

Parameters
----------
triangle : numpy.ndarray
triangle
The coordinates of a triangle with shape (3, 2).
max_badness : int
max_badness
The badness at which the point is either chosen on a edge or
in the middle.

Returns
-------
point : numpy.ndarray
point
The x and y coordinate of the suggested new point.
"""
a, b, c = triangle
Expand All @@ -267,17 +270,17 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr
return point


def triangle_loss(ip):
def triangle_loss(ip: LinearNDInterpolator) -> list[float]:
r"""Computes the average of the volumes of the simplex combined with each
neighbouring point.

Parameters
----------
ip : `scipy.interpolate.LinearNDInterpolator` instance
ip

Returns
-------
triangle_loss : list
triangle_loss
The mean volume per triangle.

Notes
Expand Down Expand Up @@ -311,13 +314,13 @@ class Learner2D(BaseLearner):

Parameters
----------
function : callable
function
The function to learn. Must take a tuple of two real
parameters and return a real number.
bounds : list of 2-tuples
bounds
A list ``[(a1, b1), (a2, b2)]`` containing bounds,
one per dimension.
loss_per_triangle : callable, optional
loss_per_triangle
A function that returns the loss for every triangle.
If not provided, then a default is used, which uses
the deviation from a linear estimate, as well as
Expand Down Expand Up @@ -424,19 +427,19 @@ def to_dataframe(

Parameters
----------
with_default_function_args : bool, optional
with_default_function_args
Include the ``learner.function``'s default arguments as a
column, by default True
function_prefix : str, optional
function_prefix
Prefix to the ``learner.function``'s default arguments' names,
by default "function."
seed_name : str, optional
seed_name
Name of the seed parameter, by default "seed"
x_name : str, optional
x_name
Name of the input x value, by default "x"
y_name : str, optional
y_name
Name of the input y value, by default "y"
z_name : str, optional
z_name
Name of the output value, by default "z"

Returns
Expand Down Expand Up @@ -475,18 +478,18 @@ def load_dataframe(

Parameters
----------
df : pandas.DataFrame
df
The data to load.
with_default_function_args : bool, optional
with_default_function_args
The ``with_default_function_args`` used in ``to_dataframe()``,
by default True
function_prefix : str, optional
function_prefix
The ``function_prefix`` used in ``to_dataframe``, by default "function."
x_name : str, optional
x_name
The ``x_name`` used in ``to_dataframe``, by default "x"
y_name : str, optional
y_name
The ``y_name`` used in ``to_dataframe``, by default "y"
z_name : str, optional
z_name
The ``z_name`` used in ``to_dataframe``, by default "z"
"""
data = df.set_index([x_name, y_name])[z_name].to_dict()
Expand Down Expand Up @@ -538,7 +541,7 @@ def interpolated_on_grid(

Parameters
----------
n : int, optional
n
Number of points in x and y. If None (default) this number is
evaluated by looking at the size of the smallest triangle.

Expand Down Expand Up @@ -611,14 +614,14 @@ def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:

Parameters
----------
scaled : bool
scaled
Use True if all points are inside the
unit-square [(-0.5, 0.5), (-0.5, 0.5)] or False if
the data points are inside the ``learner.bounds``.

Returns
-------
interpolator : `scipy.interpolate.LinearNDInterpolator`
interpolator

Examples
--------
Expand Down Expand Up @@ -755,7 +758,9 @@ def remove_unfinished(self) -> None:
if p not in self.data:
self._stack[p] = np.inf

def plot(self, n=None, tri_alpha=0):
def plot(
self, n: int = None, tri_alpha: float = 0
) -> holoviews.Overlay | holoviews.HoloMap:
r"""Plot the Learner2D's current state.

This plot function interpolates the data on a regular grid.
Expand All @@ -764,16 +769,16 @@ def plot(self, n=None, tri_alpha=0):

Parameters
----------
n : int
n
Number of points in x and y. If None (default) this number is
evaluated by looking at the size of the smallest triangle.
tri_alpha : float
tri_alpha
The opacity ``(0 <= tri_alpha <= 1)`` of the triangles overlayed
on top of the image. By default the triangulation is not visible.

Returns
-------
plot : `holoviews.core.Overlay` or `holoviews.core.HoloMap`
plot
A `holoviews.core.Overlay` of
``holoviews.Image * holoviews.EdgePaths``. If the
`learner.function` returns a vector output, a
Expand Down
Loading