Skip to content

Commit

Permalink
Added set_xlabel(), set_ylabel(), set_title(), set_xscale(), …
Browse files Browse the repository at this point in the history
…`set_yscale()`, and `set_aspect()` functions to the `named_arrays.plt` subpackage. (#78)
  • Loading branch information
byrdie authored Oct 11, 2024
1 parent 2ae5f31 commit 0574b22
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 1 deletion.
36 changes: 36 additions & 0 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
na.plt.plot,
na.plt.fill,
)
PLT_AXES_SETTERS = (
na.plt.set_xlabel,
na.plt.set_ylabel,
na.plt.set_title,
na.plt.set_xscale,
na.plt.set_yscale,
na.plt.set_aspect,
)
NDFILTER_FUNCTIONS = (
na.ndfilters.mean_filter,
na.ndfilters.trimmed_mean_filter,
Expand Down Expand Up @@ -844,6 +852,34 @@ def plt_text(
return result


def plt_axes_setter(
method: str,
*args,
ax: None | matplotlib.axes.Axes | na.AbstractScalarArray = None,
**kwargs,
) -> None:

if ax is None:
ax = plt.gca()

try:
args = [scalars._normalize(arg) for arg in args]
ax = scalars._normalize(ax)
kwargs = {k: scalars._normalize(kwargs[k]) for k in kwargs}
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

shape = ax.shape

args = [arg.broadcast_to(shape) for arg in args]
kwargs = {k: kwargs[k].broadcast_to(shape) for k in kwargs}

for index in na.ndindex(shape):
args_index = [arg[index].ndarray for arg in args]
kwargs_index = {k: kwargs[k][index].ndarray for k in kwargs}
getattr(ax[index].ndarray, method.__name__)(*args_index, **kwargs_index)


@_implements(na.jacobian)
def jacobian(
function: Callable[[na.AbstractScalar], na.AbstractScalar],
Expand Down
3 changes: 3 additions & 0 deletions named_arrays/_scalars/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,9 @@ def __named_array_function__(self, func, *args, **kwargs):
if func in scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS:
return scalar_named_array_functions.plt_plot_like(func, *args, **kwargs)

if func in scalar_named_array_functions.PLT_AXES_SETTERS:
return scalar_named_array_functions.plt_axes_setter(func, *args, **kwargs)

if func in scalar_named_array_functions.NDFILTER_FUNCTIONS:
return scalar_named_array_functions.ndfilter(func, *args, **kwargs)

Expand Down
146 changes: 145 additions & 1 deletion named_arrays/plt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
"pcolormovie",
"text",
"brace_vertical",
"set_xlabel",
"set_ylabel",
"set_title",
"set_xscale",
"set_yscale",
"set_aspect",
]


def subplots(
axis_rows: str = "subplots_row",
ncols: int = 1,
axis_cols: str = "subplots_col",
ncols: int = 1,
nrows: int = 1,
*,
sharex: bool | Literal["none", "all", "row", "col"] = False,
Expand Down Expand Up @@ -856,6 +862,144 @@ def text(
)


def set_xlabel(
xlabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_xlabel` for named arrays.
Parameters
----------
xlabel
The horizontal axis label for each axis.
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_xlabel,
xlabel=na.as_named_array(xlabel),
ax=ax,
**kwargs,
)


def set_ylabel(
ylabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_ylabel` for named arrays.
Parameters
----------
ylabel
The vertical axis label for each axis.
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_ylabel,
ylabel=na.as_named_array(ylabel),
ax=ax,
**kwargs,
)


def set_title(
label: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_title` for named arrays.
Parameters
----------
label
The title for each axis.
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_title,
label=na.as_named_array(label),
ax=ax,
**kwargs,
)


def set_xscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_xscale` for named arrays.
Parameters
----------
value
The scale type to apply to the horizontal scale of each axis.
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_xscale,
value=na.as_named_array(value),
ax=ax,
**kwargs,
)


def set_yscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_yscale` for named arrays.
Parameters
----------
value
The scale type to apply to the vertical scale of each axis.
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_yscale,
value=na.as_named_array(value),
ax=ax,
**kwargs,
)


def set_aspect(
aspect: float | str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.set_aspect` for named arrays.
Parameters
----------
aspect
The aspect ratio to apply to each axis
ax
The matplotlib axes instance on which to apply the label.
"""
return na._named_array_function(
set_aspect,
aspect=na.as_named_array(aspect),
ax=ax,
**kwargs,
)


def brace_vertical(
x: float | u.Quantity | na.AbstractScalar,
width: float | u.Quantity | na.AbstractScalar,
Expand Down
86 changes: 86 additions & 0 deletions named_arrays/tests/test_plt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import matplotlib.axes
import matplotlib.animation
import named_arrays as na

Expand Down Expand Up @@ -43,3 +44,88 @@ def test_pcolormovie(
)
assert isinstance(result, matplotlib.animation.FuncAnimation)
assert isinstance(result.to_jshtml(), str)


@pytest.mark.parametrize(
argnames="xlabel,ax",
argvalues=[
("foo", None),
("foo", na.plt.subplots(ncols=3)[1]),
]
)
def test_set_xlabel(
xlabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_xlabel(xlabel, ax=ax)


@pytest.mark.parametrize(
argnames="ylabel,ax",
argvalues=[
("foo", None),
("foo", na.plt.subplots(ncols=3)[1]),
]
)
def test_set_ylabel(
ylabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_ylabel(ylabel, ax=ax)


@pytest.mark.parametrize(
argnames="label,ax",
argvalues=[
("foo", None),
("foo", na.plt.subplots(ncols=3)[1]),
]
)
def test_set_title(
label: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_title(label, ax=ax)


@pytest.mark.parametrize(
argnames="value,ax",
argvalues=[
("log", None),
("log", na.plt.subplots(ncols=3)[1]),
]
)
def test_set_xscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_xscale(value, ax=ax)


@pytest.mark.parametrize(
argnames="value,ax",
argvalues=[
("log", None),
("log", na.plt.subplots(ncols=3)[1]),
]
)
def test_set_yscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_yscale(value, ax=ax)


@pytest.mark.parametrize(
argnames="aspect,ax",
argvalues=[
("equal", None),
("equal", na.plt.subplots(ncols=3)[1]),
(2, na.plt.subplots(ncols=3)[1]),
]
)
def test_set_aspect(
aspect: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_aspect(aspect, ax=ax)

0 comments on commit 0574b22

Please sign in to comment.