Skip to content

Commit

Permalink
style: misc style fixes and missing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Jul 18, 2024
1 parent 05c05d3 commit f5c70fe
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 30 deletions.
7 changes: 6 additions & 1 deletion src/invert4geom/cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations
from __future__ import annotations # pylint: disable=too-many-lines

import itertools
import logging
Expand All @@ -8,13 +8,18 @@
import typing

import deprecation
import harmonica as hm
import numpy as np
import pandas as pd
import sklearn
import verde as vd
import xarray as xr
from nptyping import NDArray
from polartoolkit import maps
from polartoolkit import utils as polar_utils
from tqdm.autonotebook import tqdm

import invert4geom
from invert4geom import inversion, plotting, regional, utils


Expand Down
5 changes: 2 additions & 3 deletions src/invert4geom/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,8 @@ def run_inversion(
# update the l2 and delta l2 norms
previous_delta_l2_norm = copy.copy(delta_l2_norm)
l2_norm, delta_l2_norm = update_l2_norms(
current_rmse = updated_rmse,
last_l2_norm = l2_norm,
current_rmse=updated_rmse,
last_l2_norm=l2_norm,
)
final_l2_norm = l2_norm

Expand Down Expand Up @@ -1431,7 +1431,6 @@ def run_inversion_workflow( # equivalent to monte_carlo_full_workflow
# use the best damping parameter
inversion_kwargs["solver_damping"] = best_damping


if run_zref_or_density_cv is False:
if fname is not None:
# save results to pickle
Expand Down
21 changes: 12 additions & 9 deletions src/invert4geom/optimization.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from __future__ import annotations
from __future__ import annotations # pylint: disable=too-many-lines

import itertools
import logging
import math
import multiprocessing
import os
import pathlib
import pickle
import random
import re
import subprocess
import typing
import warnings

import harmonica as hm
import optuna
import numpy as np
import optuna
import pandas as pd
import xarray as xr
from nptyping import NDArray
from tqdm.autonotebook import tqdm

from invert4geom import plotting, utils

from invert4geom import cross_validation, plotting, regional, utils

try:
import joblib
Expand Down Expand Up @@ -346,14 +349,14 @@ def optuna_parallel(

# set up parallel processing and run optimization
if parallel is True:
# @utils.supress_stdout

def optimize_study(
study_name: str,
storage: typing.Any,
storage: optuna.storages.BaseStorage,
objective: typing.Callable[..., float],
n_trials: int,
) -> None:
storage: optuna.storages.BaseStorage,
study = optuna.load_study(study_name=study_name, storage=storage)
optuna.logging.set_verbosity(optuna.logging.WARNING)
study.optimize(
objective,
Expand Down Expand Up @@ -399,11 +402,11 @@ def optuna_max_cores(
n_trials: int,
optimize_study: typing.Callable[..., None],
study_name: str,
study_storage: typing.Any,
study_storage: optuna.storages.BaseStorage,
objective: typing.Callable[..., float],
) -> None:
"""
study_storage: optuna.storages.BaseStorage,
Set up optuna optimization in parallel splitting up the number of trials over all
available cores.
"""
if joblib is None:
Expand Down
36 changes: 22 additions & 14 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations # pylint: disable=too-many-lines

import logging
import typing

import numpy as np
Expand All @@ -17,8 +18,8 @@
clear_output = None

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
except ImportError:
plt = None

Expand All @@ -32,8 +33,10 @@
import pyvista
except ImportError:
pyvista = None

import verde as vd
import xarray as xr
from polartoolkit import maps
from polartoolkit import utils as polar_utils

from invert4geom import utils
Expand Down Expand Up @@ -363,7 +366,7 @@ def plot_convergence(
if i == 0:
delta_l2_norms.append(np.nan)
else:
delta_l2_norms.append(l2_norms[i-1]/m)
delta_l2_norms.append(l2_norms[i - 1] / m)

# get tolerance values
l2_norm_tolerance = float(params["L2 norm tolerance"])
Expand Down Expand Up @@ -478,11 +481,11 @@ def plot_dynamic_convergence(
ax2 = ax1.twinx()

# plot L2-norm convergence
ax1.plot([i for i in range(len(l2_norms))], l2_norms, "b-")
ax1.plot(list(range(len(l2_norms))), l2_norms, "b-")

# plot delta L2-norm convergence
if iters > 1:
ax2.plot([i for i in range(len(delta_l2_norms))], delta_l2_norms, "g-")
ax2.plot(list(range(len(delta_l2_norms))), delta_l2_norms, "g-")

# set axis labels, ticks and gridlines
ax1.set_xlabel("Iteration")
Expand All @@ -502,7 +505,7 @@ def plot_dynamic_convergence(

# plot current L2-norm and Δ L2-norm
ax1.plot(
iters-1,
iters - 1,
l2_norms[-1],
"^",
markersize=6,
Expand All @@ -511,7 +514,7 @@ def plot_dynamic_convergence(
)
if iters > 1:
ax2.plot(
iters-1,
iters - 1,
delta_l2_norms[-1],
"^",
markersize=6,
Expand Down Expand Up @@ -541,17 +544,22 @@ def plot_dynamic_convergence(
plt.show()


def align_yaxis(ax1, v1, ax2, v2):
def align_yaxis(
ax1: mpl.axes.Axes,
v1: float,
ax2: mpl.axes.Axes,
v2: float,
) -> None:
"""
adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in ax1.
From https://stackoverflow.com/a/10482477/18686384
"""
_, y1 = ax1.transData.transform((0, v1))
_, y2 = ax2.transData.transform((0, v2))
inv = ax2.transData.inverted()
_, dy = inv.transform((0, 0)) - inv.transform((0, y1-y2))
_, dy = inv.transform((0, 0)) - inv.transform((0, y1 - y2))
miny, maxy = ax2.get_ylim()
ax2.set_ylim(miny+dy, maxy+dy)
ax2.set_ylim(miny + dy, maxy + dy)


def grid_inversion_results(
Expand Down Expand Up @@ -1211,7 +1219,7 @@ def show_prism_layers(
prisms = [prisms]

for i, j in enumerate(prisms):
# turn prisms into pyvist object
# turn prisms into pyvista object
pv_grid = j.prism_layer.to_pyvista()

trans = opacity[i] if opacity is not None else None
Expand Down Expand Up @@ -1263,7 +1271,7 @@ def combined_history(
study: optuna.study.Study,
target_names: list[str],
include_duration: bool = False,
) -> typing.Any:
) -> plotly.graph_objects.Figure:
"""
plot combined optimization history for multiobjective optimizations.
Expand All @@ -1278,7 +1286,7 @@ def combined_history(
Returns
-------
typing.Any
plotly.graph_objects.Figure
a plotly figure
"""

Expand Down Expand Up @@ -1361,7 +1369,7 @@ def plot_optuna_figures(
target_names: list[str],
include_duration: bool = False,
# params=None,
# seperate_param_importances=False,
# separate_param_importances=False,
plot_history: bool = True,
plot_slice: bool = True,
plot_importance: bool = True,
Expand Down Expand Up @@ -1417,7 +1425,7 @@ def plot_optuna_figures(
# pass
# else:
# try:
# if seperate_param_importances is True:
# if separate_param_importances is True:
# combined_importance(
# study,
# target_names,
Expand Down
8 changes: 6 additions & 2 deletions src/invert4geom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings

import dask
import deprecation
import harmonica as hm
import numpy as np
import pandas as pd
Expand All @@ -16,6 +17,9 @@
from nptyping import NDArray
from pykdtree.kdtree import KDTree # pylint: disable=no-name-in-module

import invert4geom
from invert4geom import cross_validation


def rmse(data: NDArray, as_median: bool = False) -> float:
"""
Expand Down Expand Up @@ -1011,8 +1015,8 @@ def best_spline_cv(
current_version=invert4geom.__version__,
details="function eq_sources_score has been moved to the cross_validation model.",
)
def eq_sources_score(kwargs) -> float:
def eq_sources_score(kwargs: typing.Any) -> float:
"""
deprecated function, use cross_validation.eq_sources_score instead.
"""
return cross_validation.eq_sources_score(kwargs)
return cross_validation.eq_sources_score(**kwargs)
1 change: 1 addition & 0 deletions tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def test_zref_density_optimal_parameter():
cross_validation.zref_density_optimal_parameter()


@deprecation.fail_if_not_removed
def test_grav_optimal_parameter():
cross_validation.grav_optimal_parameter()
1 change: 0 additions & 1 deletion tests/test_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def test_regional_constraints(test_input):
grav_df=anomalies,
grav_data_column="misfit",
grid_method=test_input,
eqs_gridding_trials=2,
grav_obs_height=1e3,
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import deprecation
import harmonica as hm
import numpy as np
import numpy.testing as npt
Expand Down

0 comments on commit f5c70fe

Please sign in to comment.