Skip to content

Commit

Permalink
Merge pull request #51 from Arcadia-Science/das/multisave
Browse files Browse the repository at this point in the history
Add ability to save multiple different file types
  • Loading branch information
mezarque authored Aug 5, 2024
2 parents c06e142 + 34e89ef commit 73fca02
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 25 deletions.
42 changes: 39 additions & 3 deletions arcadia_pycolor/mpl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from pathlib import Path
from typing import Any, Literal, Union, cast

import matplotlib as mpl
import matplotlib.font_manager as font_manager
import matplotlib.pyplot as plt
from matplotlib import colormaps as mpl_colormaps
from matplotlib.axis import XAxis, YAxis
from matplotlib.backend_bases import FigureCanvasBase
from matplotlib.legend import Legend
from matplotlib.lines import Line2D
from matplotlib.offsetbox import DrawingArea
Expand Down Expand Up @@ -62,11 +64,45 @@ def _arcadia_fonts_found() -> bool:
return len(arcadia_fonts) > 0


def save_figure(fname: str, context: str = "web", **savefig_kwargs: dict[Any, Any]) -> None:
"Save the current figure with the default settings for web."
def save_figure(
fname: str,
filetypes: Union[list[str], None] = None,
context: str = "web",
**savefig_kwargs: dict[Any, Any],
) -> None:
"""
Save the current figure, accounting for Arcadia's defaults.
Args:
fname (str): the filename to save the figure to
filetypes (list, optional): the filetypes(s) to save the figure to.
If None, the original filetype of fname is used.
If the original filetype is not in filetypes, it is appended to the list.
context (str): the context to save the figure in, either 'web' or 'print'
**savefig_kwargs: additional keyword arguments to pass to plt.savefig
"""
kwargs = SAVEFIG_KWARGS_WEB if context == "web" else SAVEFIG_KWARGS_PRINT
kwargs.update(**savefig_kwargs) # type: ignore
plt.savefig(fname=fname, **kwargs) # type: ignore

# Gets a list of valid filetypes for saving figures from matplotlib.
valid_filetypes = list(FigureCanvasBase.get_supported_filetypes().keys())

filetype = Path(fname).suffix[1:]
filepath_no_filetype = Path(fname).with_suffix("")

if filetypes is None:
if not filetype:
raise ValueError("The filename must include a filetype if no filetypes are provided.")
filetypes = [filetype]
else:
filetypes.append(filetype)

for ftype in filetypes:
if ftype not in valid_filetypes:
print(f"Invalid filetype '{ftype}'. Skipping.")
continue

plt.savefig(fname=f"{filepath_no_filetype}.{ftype}", **kwargs) # type: ignore


def set_yticklabel_font(
Expand Down
59 changes: 59 additions & 0 deletions arcadia_pycolor/tests/test_mpl_save_figure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import pytest

import arcadia_pycolor as apc


def simple_plot():
plt.figure(figsize=(3, 3))
plt.plot([1, 2, 3], [1, 2, 3])


@pytest.mark.parametrize(
"fname, filetypes, expected_outputs",
[
("test.pdf", None, ["test.pdf"]),
("test.pdf", ["pdf"], ["test.pdf"]),
("test.pdf", ["pdf", "png"], ["test.pdf", "test.png"]),
("test.pdf", ["png", "eps"], ["test.pdf", "test.png", "test.eps"]),
("test.pdf", ["png"], ["test.pdf", "test.png"]),
("test", ["pdf"], ["test.pdf"]),
("test", ["pdf", "png"], ["test.pdf", "test.png"]),
],
)
def test_mpl_save_figure_filetype_examples(tmp_path, fname, filetypes, expected_outputs):
"""
Test the `mpl.save_figure` function with various suffixes.
"""

simple_plot()
apc.mpl.save_figure(fname=tmp_path / fname, filetypes=filetypes)

for output in expected_outputs:
output_path = tmp_path / output
assert output_path.is_file()


@pytest.mark.parametrize(
"fname, filetypes",
[
("test.pdf", ["invalid"]),
("test", ["invalid"]),
("test.invalid", ["pdf"]),
("test.invalid", None),
],
)
def test_mpl_save_figure_filetype_invalid(tmp_path, fname, filetypes, capsys):
simple_plot()

apc.mpl.save_figure(fname=tmp_path / fname, filetypes=filetypes)

captured = capsys.readouterr()
assert "Invalid filetype 'invalid'. Skipping." in captured.out


def test_mpl_save_figure_no_filetype(tmp_path):
simple_plot()

with pytest.raises(ValueError):
apc.mpl.save_figure(fname=tmp_path / "test")
64 changes: 42 additions & 22 deletions docs/style_usage.ipynb

Large diffs are not rendered by default.

0 comments on commit 73fca02

Please sign in to comment.