Skip to content

Commit

Permalink
BUG: fix a regression in (AxisAligned)SlicePlot's API, add support fo…
Browse files Browse the repository at this point in the history
…r array-like normal arguments in (AxisAligned)ProjectionPlot
  • Loading branch information
neutrinoceros committed Jan 10, 2022
1 parent 7d8cbd4 commit 70cb913
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ def _validate_init_args(*, normal, axis, fields) -> None:
raise TypeError("Received incompatible arguments 'axis' and 'normal'")
normal = axis

if (normal, fields) == (None, None):
if normal is fields is None:
raise TypeError(
"missing 2 required positional arguments: 'normal' and 'fields'"
)
Expand Down Expand Up @@ -1939,7 +1939,7 @@ def __init__(
axis=axis,
fields=fields,
)

normal = self.sanitize_normal_vector(ds, normal)
# this will handle time series data and controllers
axis = fix_axis(normal, ds)
(bounds, center, display_center) = get_window_parameters(
Expand Down Expand Up @@ -2160,6 +2160,7 @@ def __init__(
):
# TODO: in yt 4.2, remove default values for normal and fields, drop axis kwarg
normal = self._validate_init_args(normal=normal, fields=fields, axis=axis)
normal = self.sanitize_normal_vector(ds, normal)

axis = fix_axis(normal, ds)
if ds.geometry in (
Expand Down
12 changes: 12 additions & 0 deletions yt/visualization/tests/test_normal_plot_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from itertools import product

import numpy as np
import pytest

from yt._maintenance.deprecation import VisibleDeprecationWarning
Expand Down Expand Up @@ -78,3 +81,12 @@ def test_error_with_missing_fields_with_positional(ds, plot_cls):
TypeError, match="missing required positional argument: 'fields'"
):
plot_cls(ds, "z")


@pytest.mark.parametrize(
"plot_cls, normal",
product([SlicePlot, ProjectionPlot], [(0, 0, 1), [0, 0, 1], np.array((0, 0, 1))]),
)
def test_normalplot_normal_array(ds, plot_cls, normal):
# see regression https://github.com/yt-project/yt/issues/3736
plot_cls(ds, normal, fields=("stream", "Density"))

0 comments on commit 70cb913

Please sign in to comment.