Skip to content

Commit

Permalink
Fixes from #216
Browse files Browse the repository at this point in the history
  • Loading branch information
utf committed Oct 11, 2023
1 parent 6e05b66 commit 4ea323b
Show file tree
Hide file tree
Showing 17 changed files with 65 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
python -m pip install -e '.[tests]'
- name: Test
run: pytest
run: pytest
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
Expand Down
20 changes: 6 additions & 14 deletions sumo/cli/bandplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import matplotlib as mpl
from pymatgen.electronic_structure.bandstructure import (
get_reconstructed_band_structure,
)
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
from pymatgen.electronic_structure.core import Spin
from pymatgen.io.vasp.outputs import BSVasprun

Expand Down Expand Up @@ -394,9 +393,7 @@ def bandplot(
else:
logging.info(f"Found PDOS file {pdos_file}")
else:
logging.info(
f"Cell file {cell_file} does not exist, cannot plot PDOS."
)
logging.info(f"Cell file {cell_file} does not exist, cannot plot PDOS.")

dos, pdos = read_castep_dos(
dos_file,
Expand Down Expand Up @@ -620,8 +617,7 @@ def _get_parser():
"-c",
"--code",
default="vasp",
help="Electronic structure code (default: vasp)."
'"questaal" also supported.',
help="Electronic structure code (default: vasp)." '"questaal" also supported.',
)
parser.add_argument(
"-p", "--prefix", metavar="P", help="prefix for the files generated"
Expand Down Expand Up @@ -825,9 +821,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--ymin", type=float, default=-6.0, help="minimum energy on the y-axis"
)
Expand Down Expand Up @@ -886,9 +880,7 @@ def main():
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

bandplot(
Expand Down
9 changes: 3 additions & 6 deletions sumo/cli/dosplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import matplotlib as mpl
import numpy as np

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
Expand Down Expand Up @@ -554,9 +555,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--xmin", type=float, default=-6.0, help="minimum energy on the x-axis"
)
Expand Down Expand Up @@ -635,9 +634,7 @@ def main():
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

if args.zero_energy is not None:
Expand Down
2 changes: 1 addition & 1 deletion sumo/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def labels_from_cell(cell_file, phonon=False):
line = f.readline() # Skip past block start line
while blockend.match(line.lower()) is None:
# Do not parse break lines
if 'break' not in line.lower():
if "break" not in line.lower():
kpt = tuple(map(float, line.split()[:3]))
if len(line.split()) > 3:
label = line.split()[-1]
Expand Down
29 changes: 8 additions & 21 deletions sumo/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
Subpackage providing helper functions for generating publication ready plots.
"""
from functools import wraps
import os
from functools import wraps

import matplotlib.pyplot
import numpy as np
Expand Down Expand Up @@ -43,9 +43,7 @@ def styled_plot(*style_sheets):

def decorator(get_plot):
@wraps(get_plot)
def wrapper(
*args, fonts=None, style=None, no_base_style=False, **kwargs
):
def wrapper(*args, fonts=None, style=None, no_base_style=False, **kwargs):
if no_base_style:
list_style = []
else:
Expand All @@ -58,9 +56,7 @@ def wrapper(
list_style += [style]

if fonts is not None:
list_style += [
{"font.family": "sans-serif", "font.sans-serif": fonts}
]
list_style += [{"font.family": "sans-serif", "font.sans-serif": fonts}]

matplotlib.pyplot.style.use(list_style)
return get_plot(*args, **kwargs)
Expand Down Expand Up @@ -273,9 +269,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):
"xyz": XYZColor,
}
if colorspace not in list(colorspace_mapping.keys()):
raise ValueError(
f"colorspace must be one of {colorspace_mapping.keys()}"
)
raise ValueError(f"colorspace must be one of {colorspace_mapping.keys()}")

colorspace = colorspace_mapping[colorspace]

Expand All @@ -286,19 +280,13 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# now convert to the colorspace basis for interpolation
basis1 = np.array(
convert_color(
color1_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color1_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis2 = np.array(
convert_color(
color2_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color2_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis3 = np.array(
convert_color(
color3_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color3_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)

# ensure weights is a numpy array
Expand All @@ -313,8 +301,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# convert colors to RGB
rgb_colors = [
convert_color(colorspace(*c), sRGBColor).get_value_tuple()
for c in colors
convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in colors
]

# ensure all rgb values are less than 1 (sometimes issues in interpolation
Expand Down
2 changes: 0 additions & 2 deletions sumo/plotting/optics_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import scipy.constants as scpc
from matplotlib import rcParams
from matplotlib.font_manager import FontProperties, findfont
from matplotlib.ticker import AutoMinorLocator, FuncFormatter, MaxNLocator

from sumo.plotting import (
Expand Down Expand Up @@ -242,7 +241,6 @@ def get_plot(
ax.set_ylim(ymin, ymax)

if spectrum_key == "absorption":
font = findfont(FontProperties(family=["sans-serif"]))
ax.yaxis.set_major_formatter(
FuncFormatter(curry_power_tick(times_sign=r"\times"))
)
Expand Down
5 changes: 1 addition & 4 deletions sumo/symmetry/brad_crack_kpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

from json import load as load_json
import os

import numpy as np

Expand Down Expand Up @@ -66,9 +65,7 @@ def __init__(self, structure, symprec=1e-3, spg=None):
spg_symbol = self.spg_symbol
lattice_type = self.lattice_type

bravais = self._get_bravais_lattice(
spg_symbol, lattice_type, a, b, c, unique
)
bravais = self._get_bravais_lattice(spg_symbol, lattice_type, a, b, c, unique)
self._kpath = self._get_bradcrack_data(bravais)

@staticmethod
Expand Down
16 changes: 5 additions & 11 deletions tests/tests_electronic_structure/test_optics.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import json
import unittest
import os
import unittest

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import numpy as np
from numpy.testing import assert_almost_equal
from pymatgen.io.vasp import Vasprun

from sumo.electronic_structure.optics import (
calculate_dielectric_properties,
kkr,
)
from sumo.electronic_structure.optics import calculate_dielectric_properties, kkr


class AbsorptionTestCase(unittest.TestCase):
def setUp(self):
diel_path = os.path.join(
ilr_files("tests"), "data", "Ge", "ge_diel.json"
)
diel_path = os.path.join(ilr_files("tests"), "data", "Ge", "ge_diel.json")
with open(diel_path) as f:
self.ge_diel = json.load(f)

Expand All @@ -35,9 +31,7 @@ def test_absorption(self):
self.ge_diel,
{"absorption"},
)
self.assertIsNone(
assert_almost_equal(properties["absorption"], self.ge_abs)
)
self.assertIsNone(assert_almost_equal(properties["absorption"], self.ge_abs))


class KramersKronigTestCase(unittest.TestCase):
Expand Down
Loading

0 comments on commit 4ea323b

Please sign in to comment.