Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes from #216 #227

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
python -m pip install -e '.[tests]'

- name: Test
run: pytest
run: pytest
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"h5py",
"pymatgen>=2020.10.20",
"phonopy>=2.1.3",
"matplotlib",
"matplotlib>=3.2.0",
"seekpath",
"castepxbin<1.0",
"colormath",
Expand Down
34 changes: 10 additions & 24 deletions sumo/cli/bandplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import matplotlib as mpl
from pymatgen.electronic_structure.bandstructure import (
get_reconstructed_band_structure,
)
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
from pymatgen.electronic_structure.core import Spin
from pymatgen.io.vasp.outputs import BSVasprun

Expand Down Expand Up @@ -394,9 +393,7 @@ def bandplot(
else:
logging.info(f"Found PDOS file {pdos_file}")
else:
logging.info(
f"Cell file {cell_file} does not exist, cannot plot PDOS."
)
logging.info(f"Cell file {cell_file} does not exist, cannot plot PDOS.")

dos, pdos = read_castep_dos(
dos_file,
Expand Down Expand Up @@ -620,8 +617,7 @@ def _get_parser():
"-c",
"--code",
default="vasp",
help="Electronic structure code (default: vasp)."
'"questaal" also supported.',
help="Electronic structure code (default: vasp)." '"questaal" also supported.',
)
parser.add_argument(
"-p", "--prefix", metavar="P", help="prefix for the files generated"
Expand Down Expand Up @@ -762,24 +758,20 @@ def _get_parser():
"--orbitals",
type=_el_orb,
metavar="O",
help=(
"orbitals to split into lm-decomposed "
'contributions (e.g. "Ru.d")'
),
help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')",
)
parser.add_argument(
"--atoms",
type=_atoms,
metavar="A",
help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'),
help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")',
)
parser.add_argument(
"--spin",
type=str,
default=None,
help=(
"select only one spin channel for a "
"spin-polarised calculation "
"select only one spin channel for a spin-polarised calculation "
"(options: up, 1; down, -1)"
),
)
Expand Down Expand Up @@ -829,9 +821,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--ymin", type=float, default=-6.0, help="minimum energy on the y-axis"
)
Expand Down Expand Up @@ -883,18 +873,14 @@ def main():
logging.getLogger("").addHandler(console)

if args.config is None:
config_path = os.path.join(
ilr_files("sumo.plotting"), "orbital_colours.conf"
)
config_path = ilr_files("sumo.plotting") / "orbital_colours.conf"
else:
config_path = args.config
colours = configparser.ConfigParser()
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

bandplot(
Expand Down
27 changes: 8 additions & 19 deletions sumo/cli/dosplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import matplotlib as mpl
import numpy as np

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
Expand Down Expand Up @@ -441,9 +442,7 @@ def _get_parser():
"--code",
default="vasp",
metavar="C",
help=(
'Input file format: "vasp" (vasprun.xml) or ' '"questaal" (opt.ext)'
),
help='Input file format: "vasp" (vasprun.xml) or "questaal" (opt.ext)',
)
parser.add_argument(
"-p", "--prefix", metavar="P", help="prefix for the files generated"
Expand All @@ -463,25 +462,21 @@ def _get_parser():
"--orbitals",
type=_el_orb,
metavar="O",
help=(
"orbitals to split into lm-decomposed "
'contributions (e.g. "Ru.d")'
),
help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')",
)
parser.add_argument(
"-a",
"--atoms",
type=_atoms,
metavar="A",
help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'),
help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")',
)
parser.add_argument(
"--spin",
type=str,
default=None,
help=(
"select one spin channel only for a "
"spin-polarised calculation "
"select one spin channel only for a spin-polarised calculation "
"(options: up, 1; down, -1)"
),
)
Expand Down Expand Up @@ -560,9 +555,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--xmin", type=float, default=-6.0, help="minimum energy on the x-axis"
)
Expand Down Expand Up @@ -634,18 +627,14 @@ def main():
logging.getLogger("").addHandler(console)

if args.config is None:
config_path = os.path.join(
ilr_files("sumo.plotting"), "orbital_colours.conf"
)
config_path = ilr_files("sumo.plotting") / "orbital_colours.conf"
else:
config_path = args.config
colours = configparser.ConfigParser()
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

if args.zero_energy is not None:
Expand Down
2 changes: 1 addition & 1 deletion sumo/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def labels_from_cell(cell_file, phonon=False):
line = f.readline() # Skip past block start line
while blockend.match(line.lower()) is None:
# Do not parse break lines
if 'break' not in line.lower():
if "break" not in line.lower():
kpt = tuple(map(float, line.split()[:3]))
if len(line.split()) > 3:
label = line.split()[-1]
Expand Down
43 changes: 13 additions & 30 deletions sumo/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
Subpackage providing helper functions for generating publication ready plots.
"""
from functools import wraps
import os
from functools import wraps

import matplotlib.pyplot
import numpy as np
Expand All @@ -19,15 +19,11 @@

colour_cache = {}

sumo_base_style = os.path.join(ilr_files("sumo.plotting"), "sumo_base.mplstyle")
sumo_dos_style = os.path.join(ilr_files("sumo.plotting"), "sumo_dos.mplstyle")
sumo_bs_style = os.path.join(ilr_files("sumo.plotting"), "sumo_bs.mplstyle")
sumo_phonon_style = os.path.join(
ilr_files("sumo.plotting"), "sumo_phonon.mplstyle"
)
sumo_optics_style = os.path.join(
ilr_files("sumo.plotting"), "sumo_optics.mplstyle"
)
sumo_base_style = ilr_files("sumo.plotting") / "sumo_base.mplstyle"
sumo_dos_style = ilr_files("sumo.plotting") / "sumo_dos.mplstyle"
sumo_bs_style = ilr_files("sumo.plotting") / "sumo_bs.mplstyle"
sumo_phonon_style = ilr_files("sumo.plotting") / "sumo_phonon.mplstyle"
sumo_optics_style = ilr_files("sumo.plotting") / "sumo_optics.mplstyle"


def styled_plot(*style_sheets):
Expand All @@ -47,9 +43,7 @@ def styled_plot(*style_sheets):

def decorator(get_plot):
@wraps(get_plot)
def wrapper(
*args, fonts=None, style=None, no_base_style=False, **kwargs
):
def wrapper(*args, fonts=None, style=None, no_base_style=False, **kwargs):
if no_base_style:
list_style = []
else:
Expand All @@ -62,9 +56,7 @@ def wrapper(
list_style += [style]

if fonts is not None:
list_style += [
{"font.family": "sans-serif", "font.sans-serif": fonts}
]
list_style += [{"font.family": "sans-serif", "font.sans-serif": fonts}]

matplotlib.pyplot.style.use(list_style)
return get_plot(*args, **kwargs)
Expand Down Expand Up @@ -277,9 +269,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):
"xyz": XYZColor,
}
if colorspace not in list(colorspace_mapping.keys()):
raise ValueError(
f"colorspace must be one of {colorspace_mapping.keys()}"
)
raise ValueError(f"colorspace must be one of {colorspace_mapping.keys()}")

colorspace = colorspace_mapping[colorspace]

Expand All @@ -290,19 +280,13 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# now convert to the colorspace basis for interpolation
basis1 = np.array(
convert_color(
color1_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color1_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis2 = np.array(
convert_color(
color2_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color2_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis3 = np.array(
convert_color(
color3_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color3_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)

# ensure weights is a numpy array
Expand All @@ -317,8 +301,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# convert colors to RGB
rgb_colors = [
convert_color(colorspace(*c), sRGBColor).get_value_tuple()
for c in colors
convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in colors
]

# ensure all rgb values are less than 1 (sometimes issues in interpolation
Expand Down
2 changes: 0 additions & 2 deletions sumo/plotting/optics_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import scipy.constants as scpc
from matplotlib import rcParams
from matplotlib.font_manager import FontProperties, findfont
from matplotlib.ticker import AutoMinorLocator, FuncFormatter, MaxNLocator

from sumo.plotting import (
Expand Down Expand Up @@ -242,7 +241,6 @@ def get_plot(
ax.set_ylim(ymin, ymax)

if spectrum_key == "absorption":
font = findfont(FontProperties(family=["sans-serif"]))
ax.yaxis.set_major_formatter(
FuncFormatter(curry_power_tick(times_sign=r"\times"))
)
Expand Down
7 changes: 2 additions & 5 deletions sumo/symmetry/brad_crack_kpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

from json import load as load_json
import os

import numpy as np

Expand Down Expand Up @@ -66,9 +65,7 @@ def __init__(self, structure, symprec=1e-3, spg=None):
spg_symbol = self.spg_symbol
lattice_type = self.lattice_type

bravais = self._get_bravais_lattice(
spg_symbol, lattice_type, a, b, c, unique
)
bravais = self._get_bravais_lattice(spg_symbol, lattice_type, a, b, c, unique)
self._kpath = self._get_bradcrack_data(bravais)

@staticmethod
Expand All @@ -85,7 +82,7 @@ def _get_bradcrack_data(bravais):
'path': [['\Gamma', 'X', ..., 'P'], ['H', 'N', ...]]}

"""
json_file = os.path.join(ilr_files("sumo.symmetry"), "bradcrack.json")
json_file = ilr_files("sumo.symmetry") / "bradcrack.json"
with open(json_file) as f:
bradcrack_data = load_json(f)
return bradcrack_data[bravais]
Expand Down
16 changes: 5 additions & 11 deletions tests/tests_electronic_structure/test_optics.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import json
import unittest
import os
import unittest

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import numpy as np
from numpy.testing import assert_almost_equal
from pymatgen.io.vasp import Vasprun

from sumo.electronic_structure.optics import (
calculate_dielectric_properties,
kkr,
)
from sumo.electronic_structure.optics import calculate_dielectric_properties, kkr


class AbsorptionTestCase(unittest.TestCase):
def setUp(self):
diel_path = os.path.join(
ilr_files("tests"), "data", "Ge", "ge_diel.json"
)
diel_path = os.path.join(ilr_files("tests"), "data", "Ge", "ge_diel.json")
with open(diel_path) as f:
self.ge_diel = json.load(f)

Expand All @@ -35,9 +31,7 @@ def test_absorption(self):
self.ge_diel,
{"absorption"},
)
self.assertIsNone(
assert_almost_equal(properties["absorption"], self.ge_abs)
)
self.assertIsNone(assert_almost_equal(properties["absorption"], self.ge_abs))


class KramersKronigTestCase(unittest.TestCase):
Expand Down
Loading