Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixup mypy matplotlib #575

Merged
merged 5 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions openfe/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy.typing as npt
from openff.units import unit
from typing import Optional, Union


def plot_lambda_transition_matrix(matrix: npt.NDArray) -> plt.Axes:
def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes:
"""
Plot out a transition matrix.

Expand All @@ -17,7 +18,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> plt.Axes:

Returns
-------
ax : matplotlib.pyplot.Axes
ax : matplotlib.axes.Axes
An Axes object to plot.
"""
num_states = len(matrix)
Expand Down Expand Up @@ -79,7 +80,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> plt.Axes:
def plot_convergence(
forward_and_reverse: dict[str, Union[npt.NDArray, unit.Quantity]],
units: unit.Quantity
) -> plt.Axes:
) -> Axes:
"""
Plot a Reverse and Forward convergence analysis of the
free energies.
Expand All @@ -95,7 +96,7 @@ def plot_convergence(

Returns
-------
ax : matplotlib.pyplot.Axes
ax : matplotlib.axes.Axes
An Axes object to plot.
"""
known_units = {
Expand Down Expand Up @@ -165,7 +166,7 @@ def plot_convergence(
def plot_replica_timeseries(
state_timeseries: npt.NDArray,
equilibration_iterations: Optional[int] = None,
) -> plt.Axes:
) -> Axes:
"""
Plot a the state timeseries of a set of replicas.

Expand All @@ -178,7 +179,7 @@ def plot_replica_timeseries(

Returns
-------
ax : matplotlib.pyplot.Axes
ax : matplotlib.axes.Axes
An Axes object to plot.
"""
num_states = len(state_timeseries.T)
Expand Down
8 changes: 4 additions & 4 deletions openfe/protocols/openmm_utils/multistate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def plot(self, filepath: Path, filename_prefix: str):
# MBAR overlap matrix
ax = plotting.plot_lambda_transition_matrix(self.free_energy_overlaps['matrix'])
ax.set_title('MBAR overlap matrix')
ax.figure.savefig(
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'mbar_overlap_matrix.png')
)

Expand All @@ -83,7 +83,7 @@ def plot(self, filepath: Path, filename_prefix: str):
self.forward_and_reverse_free_energies, self.units
)
ax.set_title('Forward and Reverse free energy convergence')
ax.figure.savefig(
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'forward_reverse_convergence.png')
)

Expand All @@ -92,7 +92,7 @@ def plot(self, filepath: Path, filename_prefix: str):
self.replica_states, self.equilibration_iterations
)
ax.set_title('Change in replica state over time')
ax.figure.savefig(
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'replica_state_timeseries.png')
)

Expand All @@ -102,7 +102,7 @@ def plot(self, filepath: Path, filename_prefix: str):
self.replica_exchange_statistics['matrix']
)
ax.set_title('Replica exchange transition matrix')
ax.figure.savefig(
ax.figure.savefig( # type: ignore
filepath / (filename_prefix + 'replica_exchange_matrix.png')
)

Expand Down
34 changes: 17 additions & 17 deletions openfe/utils/network_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D

from typing import Dict, List, Tuple, Optional, Any, Union, cast
from typing import Optional, Any, Union, cast
from openfe.utils.custom_typing import (
MPL_MouseEvent, MPL_FigureCanvasBase, MPL_Axes, TypeAlias
)

ClickLocation: TypeAlias = Tuple[Tuple[float, float], Tuple[Any, Any]]
ClickLocation: TypeAlias = tuple[tuple[float, float], tuple[Any, Any]]


class Node:
Expand Down Expand Up @@ -54,14 +54,14 @@ def register_artist(self, ax: MPL_Axes):
ax.add_patch(self.artist)

@property
def extent(self) -> Tuple[float, float, float, float]:
def extent(self) -> tuple[float, float, float, float]:
"""extent of this node in matplotlib data coordinates"""
bounds = self.artist.get_bbox().bounds
return (bounds[0], bounds[0] + bounds[2],
bounds[1], bounds[1] + bounds[3])

@property
def xy(self) -> Tuple[float, float]:
def xy(self) -> tuple[float, float]:
"""lower left (matplotlib data coordinates) position of this node"""
return self.artist.xy

Expand Down Expand Up @@ -153,14 +153,14 @@ class Edge:
"""
pickable = True

def __init__(self, node_artist1: Node, node_artist2: Node, data: Dict):
def __init__(self, node_artist1: Node, node_artist2: Node, data: dict):
self.data = data
self.node_artists = [node_artist1, node_artist2]
self.artist = self._make_artist(node_artist1, node_artist2, data)
self.picked = False

def _make_artist(self, node_artist1: Node, node_artist2: Node,
data: Dict) -> Any:
data: dict) -> Any:
xs, ys = self._edge_xs_ys(node_artist1, node_artist2)
return Line2D(xs, ys, color='black', picker=True, zorder=-1)

Expand Down Expand Up @@ -238,7 +238,7 @@ class EventHandler:
selected : Optional[Union[Node, Edge]]
Object selected by a mouse click (after mouse is up), or None if no
object has been selected in the graph.
click_location : Optional[Tuple[int, int]]
click_location : Optional[tuple[Optional[float], Optional[float]]]
Cached location of the mousedown event, or None if mouse is up
connections : List[int]
list of IDs for connections to matplotlib canvas
Expand All @@ -247,15 +247,15 @@ def __init__(self, graph: GraphDrawing):
self.graph = graph
self.active: Optional[Union[Node, Edge]] = None
self.selected: Optional[Union[Node, Edge]] = None
self.click_location: Optional[Tuple[int, int]] = None
self.connections: List[int] = []
self.click_location: Optional[tuple[Optional[float], Optional[float]]] = None
self.connections: list[int] = []

def connect(self, canvas: MPL_FigureCanvasBase):
"""Connect our methods to events in the matplotlib canvas"""
self.connections.extend([
canvas.mpl_connect('button_press_event', self.on_mousedown),
canvas.mpl_connect('motion_notify_event', self.on_drag),
canvas.mpl_connect('button_release_event', self.on_mouseup),
canvas.mpl_connect('button_press_event', self.on_mousedown), # type: ignore
canvas.mpl_connect('motion_notify_event', self.on_drag), # type: ignore
canvas.mpl_connect('button_release_event', self.on_mouseup), # type: ignore
])

def disconnect(self, canvas: MPL_FigureCanvasBase):
Expand Down Expand Up @@ -346,8 +346,8 @@ def __init__(self, graph: nx.Graph, positions=None, ax=None):
# TODO: use scale to scale up the positions?
self.event_handler = EventHandler(self)
self.graph = graph
self.nodes: Dict[Node, Any] = {}
self.edges: Dict[Tuple[Node, Node], Any] = {}
self.nodes: dict[Node, Any] = {}
self.edges: dict[tuple[Node, Node], Any] = {}

if positions is None:
positions = nx.nx_agraph.graphviz_layout(self.graph, prog='neato')
Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(self, graph: nx.Graph, positions=None, ax=None):
def _ipython_display_(self): # -no-cov-
return self.fig

def edges_for_node(self, node: Node) -> List[Edge]:
def edges_for_node(self, node: Node) -> list[Edge]:
"""List of edges for the given node"""
edges = (list(self.graph.in_edges(node))
+ list(self.graph.out_edges(node)))
Expand Down Expand Up @@ -410,7 +410,7 @@ def draw(self):
self.fig.canvas.draw()
self.fig.canvas.flush_events()

def _register_node(self, node: Any, position: Tuple[float, float]):
def _register_node(self, node: Any, position: tuple[float, float]):
"""Create and register ``Node`` from NetworkX node and position"""
if node in self.nodes:
raise RuntimeError("node provided multiple times")
Expand All @@ -419,7 +419,7 @@ def _register_node(self, node: Any, position: Tuple[float, float]):
self.nodes[node] = draw_node
draw_node.register_artist(self.ax)

def _register_edge(self, edge: Tuple[Node, Node, Dict]):
def _register_edge(self, edge: tuple[Node, Node, dict]):
"""Create and register ``Edge`` from NetworkX edge information"""
node1, node2, data = edge
draw_edge = self.EdgeCls(self.nodes[node1], self.nodes[node2], data)
Expand Down