Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Converting sisl_toolbox to typer. #608

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 13 additions & 65 deletions src/sisl_toolbox/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,73 +7,21 @@
accessible.
"""

import typer
from ._typer_wrappers import annotate_typer

class SToolBoxCLI:
""" Run the CLI `stoolbox` """
from sisl_toolbox.siesta.atom._atom import atom_plot
from sisl_toolbox.transiesta.poisson.fftpoisson_fix import fftpoisson_fix

def __init__(self):
self._cmds = []
app = typer.Typer(
name="Sisl toolbox",
help="Specific toolboxes to aid sisl users",
rich_markup_mode="markdown",
add_completion=False
)

def register(self, setup):
""" Register a setup callback function which creates the subparser
app.command()(annotate_typer(atom_plot))
app.command("ts-fft")(annotate_typer(fftpoisson_fix))

The ``setup(..)`` command must accept a sub-parser from `argparse` as its
first argument.
stoolbox_cli = app

The only requirements to create a sub-command is to fullfill these requirements:

1. Create a new parser using ``subp.add_parser``.
2. Ensure a runner is attached to the subparser through ``.set_defaults(runner=<callable>)``

A minimal example would be:

>>> def setup(subp):
... p = subp.add_parser("test-sub")
... def test_sub_method(args):
... print(args)
... p.set_defaults(runner=test_sub_method)
"""
self._cmds.append(setup)

def __call__(self, argv=None):
import argparse
import sys
from pathlib import Path

# Create command-line
cmd = Path(sys.argv[0])
p = argparse.ArgumentParser(f"{cmd.name}",
description="Specific toolboxes to aid sisl users")

info = {
"title": "Toolboxes",
"metavar": "TOOL",
}

# Check which Python version we have
version = sys.version_info
if version.major >= 3 and version.minor >= 7:
info["required"] = True

# Create the sub-parser
subp = p.add_subparsers(**info)

for cmd in self._cmds:
cmd(subp)

args = p.parse_args(argv)
args.runner(args)


# Populate the commands

# First create the class to hold and dynamically create the commands
stoolbox_cli = SToolBoxCLI()

from sisl_toolbox.transiesta.poisson.fftpoisson_fix import fftpoisson_fix_cli

stoolbox_cli.register(fftpoisson_fix_cli)

from sisl_toolbox.siesta.atom._atom import atom_plot_cli

stoolbox_cli.register(atom_plot_cli)
54 changes: 54 additions & 0 deletions src/sisl_toolbox/cli/_cli_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Classes that hold information regarding how a given parameter should behave in a CLI
# They are meant to be used as metadata for the type annotations. That is, passing them
# to Annotated. E.g.: Annotated[int, CLIArgument(option="some_option")]. Even if they
# are empty, they indicate whether to treat the parameter as an argument or an option.
class CLIArgument:
def __init__(self, **kwargs):
self.kwargs = kwargs

class CLIOption:
def __init__(self, *param_decls: str, **kwargs):
if len(param_decls) > 0:
kwargs["param_decls"] = param_decls
self.kwargs = kwargs

def get_params_help(func) -> dict:
"""Gets the text help of parameters from the docstring"""
params_help = {}

in_parameters = False
read_key = None
arg_content = ""

for line in func.__doc__.split("\n"):
if "Parameters" in line:
in_parameters = True
space = line.find("Parameters")
continue

if in_parameters:
if len(line) < space + 1:
continue
if len(line) > 1 and line[0] != " ":
break

if line[space] not in (" ", "-"):
if read_key is not None:
params_help[read_key] = arg_content

read_key = line.split(":")[0].strip()
arg_content = ""
else:
if arg_content == "":
arg_content = line.strip()
arg_content = arg_content[0].upper() + arg_content[1:]
else:
arg_content += " " + line.strip()

if line.startswith("------"):
break

if read_key is not None:
params_help[read_key] = arg_content

return params_help
112 changes: 112 additions & 0 deletions src/sisl_toolbox/cli/_typer_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import typing
from typing_extensions import Annotated

from enum import Enum

import inspect
from copy import copy
import yaml

import typer

from ._cli_arguments import CLIArgument, CLIOption, get_params_help

def get_dict_param_kwargs(dict_annotation_args):

def yaml_dict(d: str):

if isinstance(d, dict):
return d

return yaml.safe_load(d)

argument_kwargs = {"parser": yaml_dict}

if len(dict_annotation_args) == 2:
try:
argument_kwargs["metavar"] = f"YAML_DICT[{dict_annotation_args[0].__name__}: {dict_annotation_args[1].__name__}]"
except:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.
argument_kwargs["metavar"] = f"YAML_DICT[{dict_annotation_args[0]}: {dict_annotation_args[1]}]"

return argument_kwargs

# This dictionary keeps the kwargs that should be passed to typer arguments/options
# for a given type. This is for example to be used for types that typer does not
# have built in support for.
_CUSTOM_TYPE_KWARGS = {
dict: get_dict_param_kwargs,
}

def _get_custom_type_kwargs(type_):

if hasattr(type_, "__metadata__"):
type_ = type_.__origin__

if typing.get_origin(type_) is not None:
args = typing.get_args(type_)
type_ = typing.get_origin(type_)
else:
args = ()

try:
argument_kwargs = _CUSTOM_TYPE_KWARGS.get(type_, {})
if callable(argument_kwargs):
argument_kwargs = argument_kwargs(args)
except:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.
argument_kwargs = {}

return argument_kwargs


def annotate_typer(func):
"""Annotates a function for a typer app.

It returns a new function, the original function is not modified.
"""
# Get the help message for all parameters found at the docstring
params_help = get_params_help(func)

# Get the original signature of the function
sig = inspect.signature(func)

# Loop over parameters in the signature, modifying them to include the
# typer info.
new_parameters = []
for param in sig.parameters.values():

argument_kwargs = _get_custom_type_kwargs(param.annotation)

default = param.default
if isinstance(param.default, Enum):
default = default.value

typer_arg_cls = typer.Argument if param.default == inspect.Parameter.empty else typer.Option
if hasattr(param.annotation, "__metadata__"):
for meta in param.annotation.__metadata__:
if isinstance(meta, CLIArgument):
typer_arg_cls = typer.Argument
argument_kwargs.update(meta.kwargs)
elif isinstance(meta, CLIOption):
typer_arg_cls = typer.Option
argument_kwargs.update(meta.kwargs)

if "param_decls" in argument_kwargs:
argument_args = argument_kwargs.pop("param_decls")
else:
argument_args = []

new_parameters.append(
param.replace(
default=default,
annotation=Annotated[param.annotation, typer_arg_cls(*argument_args, help=params_help.get(param.name), **argument_kwargs)]
)
)

# Create a copy of the function and update it with the modified signature.
# Also remove parameters documentation from the docstring.
annotated_func = copy(func)

annotated_func.__signature__ = sig.replace(parameters=new_parameters)
annotated_func.__doc__ = func.__doc__[:func.__doc__.find("Parameters\n")]

return annotated_func
95 changes: 48 additions & 47 deletions src/sisl_toolbox/siesta/atom/_atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
which will show 4 plots for different sections. A command-line tool is also
made available through the `stoolbox`.
"""
from typing import Optional, List
from typing_extensions import Annotated

import sys
from collections.abc import Iterable
from enum import Enum
from functools import reduce
from pathlib import Path

Expand All @@ -33,7 +37,9 @@
import sisl as si
from sisl.utils import NotNonePropertyDict, PropertyDict

__all__ = ["AtomInput", "atom_plot_cli"]
from sisl_toolbox.cli._cli_arguments import CLIArgument, CLIOption

__all__ = ["AtomInput", "atom_plot"]


_script = Path(sys.argv[0]).name
Expand Down Expand Up @@ -737,63 +743,58 @@ def next_rc(ir, ic, nrows, ncols):
return fig, axs


def atom_plot_cli(subp=None):
""" Run plotting command for the output of atom """

is_sub = not subp is None

title = "Plotting facility for atom output (run in the atom output directory)"
if is_sub:
global _script
_script = f"{_script} atom-plot"
p = subp.add_parser("atom-plot", description=title, help=title)
else:
import argparse
p = argparse.ArgumentParser(title)

p.add_argument("--plot", '-P', action='append', type=str,
choices=('wavefunction', 'charge', 'log', 'potential'),
help="""Determine what to plot""")

p.add_argument("-l", default='spdf', type=str,
help="""Which l shells to plot""")

p.add_argument("--save", "-S", default=None,
help="""Save output plots to file.""")

p.add_argument("--show", default=False, action='store_true',
help="""Force showing the plot (only if --save is specified)""")
class AtomPlotOption(Enum):
"""Plotting options for atom"""
wavefunction = 'wavefunction'
charge = 'charge'
log = 'log'
potential = 'potential'

p.add_argument("input", type=str, default="INP",
help="""Input file name (default INP)""")
def atom_plot(
plot: Annotated[Optional[List[AtomPlotOption]], CLIArgument()] = None,
input: Path = Path("INP"),
l: str = 'spdf',
save: Annotated[Optional[str], CLIOption("-S", "--save")] = None,
show: bool = False
):
"""Plotting facility for atom output (run in the atom output directory)

if is_sub:
p.set_defaults(runner=atom_plot)
else:
atom_plot(p.parse_args())


def atom_plot(args):
Parameters
----------
plot:
Determine what to plot. If None is given, it plots everything.
input:
Input file name.
l:
Which l shells to plot.
save:
Save output plots to file.
show:
Force showing the plot.
"""
import matplotlib.pyplot as plt

input = Path(args.input)
atom = AtomInput.from_input(input)
input_path = Path(input)
atom = AtomInput.from_input(input_path)

# If the specified input is a file, use the parent
# Otherwise use the input *as is*.
if input.is_file():
path = input.parent
if input_path.is_file():
path = input_path.parent
else:
path = input
path = input_path

# if users have not specified what to plot, we plot everything
if args.plot is None:
args.plot = ('wavefunction', 'charge', 'log', 'potential')
fig = atom.plot(path, plot=args.plot, l=args.l, show=False)[0]
if plot is None:
plots = [p.value for p in AtomPlotOption]
else:
plots = [p.value if isinstance(p, AtomPlotOption) else p for p in plot ]

if args.save is None:
fig = atom.plot(path, plot=plots, l=l, show=False)[0]

if save is None:
plt.show()
else:
fig.savefig(args.save)
if args.show:
fig.savefig(save)
if show:
plt.show()
Loading
Loading