Skip to content

Commit

Permalink
FEAT: refactor ProjectionPlot into a factor function similar to Slice…
Browse files Browse the repository at this point in the history
…Plot, introducing the AxisAlignedProjectionPlot class. Add snake case aliases projection_plot -> ProjectionPlot and slice_plot -> SlicePlot, and update SlicePlot's docstring
  • Loading branch information
neutrinoceros committed Jul 24, 2021
1 parent dd3f795 commit ce567af
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 58 deletions.
2 changes: 2 additions & 0 deletions yt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@
apply_colormap,
make_colormap,
plot_2d,
projection_plot,
scale_image,
show_colormaps,
slice_plot,
write_bitmap,
write_image,
write_projection,
Expand Down
2 changes: 2 additions & 0 deletions yt/visualization/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
ProjectionPlot,
SlicePlot,
plot_2d,
projection_plot,
slice_plot,
)
from .profile_plotter import PhasePlot, ProfilePlot
from .streamlines import Streamlines
202 changes: 144 additions & 58 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from functools import wraps
from numbers import Number
from typing import Union

import matplotlib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -1628,7 +1629,7 @@ def __init__(
self.set_axes_unit(axes_unit)


class ProjectionPlot(PWViewerMPL):
class AxisAlignedProjectionPlot(PWViewerMPL):
r"""Creates a projection plot from a dataset
Given a ds object, an axis to project along, and a field name
Expand Down Expand Up @@ -1772,7 +1773,7 @@ class ProjectionPlot(PWViewerMPL):
>>> from yt import load
>>> ds = load("IsolateGalaxygalaxy0030/galaxy0030")
>>> p = ProjectionPlot(ds, "z", ("gas", "density"), width=(20, "kpc"))
>>> p = AxisAlignedProjectionPlot(ds, "z", ("gas", "density"), width=(20, "kpc"))
"""
_plot_type = "Projection"
Expand Down Expand Up @@ -2293,14 +2294,61 @@ def _create_axes(self, axrect):
self.axes = self.figure.add_axes(axrect, projection=self._projection)


def SlicePlot(ds, normal=None, fields=None, axis=None, *args, **kwargs):
def _sanitize_normal_vector(ds, normal) -> Union[str, np.ndarray]:
"""Return the name of a cartesian axis whener possible,
or a 3-element 1D ndarray of float64 in any other valid case.
Fail with a descriptive error message otherwise.
"""
axis_names = ds.coordinates.axis_order

if isinstance(normal, str):
if normal not in axis_names:
raise ValueError(
f"'{normal}' is not a valid axis name. " "Expected either ."
)
return normal

if isinstance(normal, (int, np.integer)):
if normal not in (0, 1, 2):
raise ValueError(
f"{normal} is not a valid axis identifier. Expected either 0, 1, or 2."
)
return axis_names[normal]

if not is_sequence(normal):
raise TypeError(
f"{normal} is not a valid normal vector identifier. "
"Expected a string, integer or sequence of 3 floats."
)

if len(normal) != 3:
raise ValueError(
f"{normal} with length {len(normal)} is not a valid normal vector. "
"Expected a 3-element sequence."
)

try:
normal = np.array(normal, dtype="float64")
except ValueError as exc:
raise TypeError(
f"{normal} is not a valid normal vector: "
"some elements cannot be converted to float."
) from exc

if np.count_nonzero(normal) == 1:
return axis_names[np.nonzero(normal)[0][0]]

return normal


def SlicePlot(ds, normal, fields, *args, **kwargs):
r"""
A factory function for
:class:`yt.visualization.plot_window.AxisAlignedSlicePlot`
and :class:`yt.visualization.plot_window.OffAxisSlicePlot` objects. This
essentially allows for a single entry point to both types of slice plots,
the distinction being determined by the specified normal vector to the
slice.
projection.
The returned plot object can be updated using one of the many helper
functions defined in PlotWindow.
Expand All @@ -2311,18 +2359,19 @@ def SlicePlot(ds, normal=None, fields=None, axis=None, *args, **kwargs):
ds : :class:`yt.data_objects.static_output.Dataset`
This is the dataset object corresponding to the
simulation output to be plotted.
normal : int or one of 'x', 'y', 'z', or sequence of floats
This specifies the normal vector to the slice. If given as an integer
or a coordinate string (0=x, 1=y, 2=z), this function will return an
:class:`AxisAlignedSlicePlot` object. If given as a sequence of floats,
this is interpreted as an off-axis vector and an
:class:`OffAxisSlicePlot` object is returned.
fields : string
normal : int, str, or 3-element sequence of floats
This specifies the normal vector to the slice.
Valid int values are 0, 1 and 2. Coresponding str values depend on the
geometry of the dataset and are generally given by `ds.coordinates.axis_order`.
E.g. in cartesian they are 'x', 'y' and 'z'.
An arbitrary normal vector may be specified as a 3-element sequence of floats.
This function will return a :class:`OffAxisSlicePlot` object or a
:class:`AxisAlignedSlicePlot` object, depending on wether the requested
normal directions corresponds to a natural axis of the dataset's geometry.
fields : a (or a list of) 2-tuple of strings (ftype, fname)
The name of the field(s) to be plotted.
axis : int or one of 'x', 'y', 'z'
An int corresponding to the axis to slice along (0=x, 1=y, 2=z)
or the axis name itself. If specified, this will replace normal.
The following are nominally keyword arguments passed onto the respective
slice plot objects generated by this function.
Expand Down Expand Up @@ -2416,53 +2465,34 @@ def SlicePlot(ds, normal=None, fields=None, axis=None, *args, **kwargs):
Raises
------
AssertionError
If a proper normal axis is not specified via the normal or axis
keywords, and/or if a field to plot is not specified.
ValueError or TypeError
If `normal` cannot be interpreted as a valid normal direction.
Examples
--------
>>> from yt import load
>>> ds = load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> slc = SlicePlot(ds, "x", ("gas", "density"), center=[0.2, 0.3, 0.4])
>>> slc = slice_plot(ds, "x", ("gas", "density"), center=[0.2, 0.3, 0.4])
>>> slc = SlicePlot(
>>> slc = slice_plot(
... ds, [0.4, 0.2, -0.1], ("gas", "pressure"), north_vector=[0.2, -0.3, 0.1]
... )
"""
if axis is not None:
issue_deprecation_warning(
"SlicePlot's argument 'axis' is a deprecated alias for 'normal', it "
"will be removed in a future version of yt.",
since="4.0.0",
removal="4.1.0",
)
if normal is not None:
raise TypeError(
"SlicePlot() received incompatible arguments 'axis' and 'normal'"
)
normal = axis

# to keep positional ordering we had to make 'normal' and 'fields' keywords
if normal is None:
raise TypeError("Missing argument in SlicePlot(): 'normal'")
normal = _sanitize_normal_vector(ds, normal)

if fields is None:
raise TypeError("Missing argument in SlicePlot(): 'fields'")

# use an AxisAlignedSlicePlot where possible, e.g.:
# maybe someone passed normal=[0,0,0.2] when they should have just used "z"
if is_sequence(normal) and not isinstance(normal, str):
if np.count_nonzero(normal) == 1:
normal = ("x", "y", "z")[np.nonzero(normal)[0][0]]
else:
normal = np.array(normal, dtype="float64")
np.divide(normal, np.dot(normal, normal), normal)
if isinstance(normal, str):
# north_vector not used in AxisAlignedSlicePlots; remove it if in kwargs
if "north_vector" in kwargs:
mylog.warning(
"Ignoring 'north_vector' keyword as it is ill-defined for "
"an AxisAlignedSlicePlot object."
)
del kwargs["north_vector"]

# by now the normal should be properly set to get either a On/Off Axis plot
if is_sequence(normal) and not isinstance(normal, str):
return AxisAlignedSlicePlot(ds, normal, fields, *args, **kwargs)
else:
# OffAxisSlicePlot has hardcoded origin; remove it if in kwargs
if "origin" in kwargs:
mylog.warning(
Expand All @@ -2472,16 +2502,57 @@ def SlicePlot(ds, normal=None, fields=None, axis=None, *args, **kwargs):
del kwargs["origin"]

return OffAxisSlicePlot(ds, normal, fields, *args, **kwargs)
else:
# north_vector not used in AxisAlignedSlicePlots; remove it if in kwargs
if "north_vector" in kwargs:
mylog.warning(
"Ignoring 'north_vector' keyword as it is ill-defined for "
"an AxisAlignedSlicePlot object."
)
del kwargs["north_vector"]

return AxisAlignedSlicePlot(ds, normal, fields, *args, **kwargs)

def ProjectionPlot(ds, normal, fields, *args, **kwargs):
r"""
A factory function for
:class:`yt.visualization.plot_window.AxisAlignedProjectionPlot`
and :class:`yt.visualization.plot_window.OffAxisProjectionPlot` objects. This
essentially allows for a single entry point to both types of projection plots,
the distinction being determined by the specified normal vector to the
slice.
The returned plot object can be updated using one of the many helper
functions defined in PlotWindow.
Parameters
----------
ds : :class:`yt.data_objects.static_output.Dataset`
This is the dataset object corresponding to the
simulation output to be plotted.
normal : int, str, or 3-element sequence of floats
This specifies the normal vector to the slice.
Valid int values are 0, 1 and 2. Coresponding str values depend on the
geometry of the dataset and are generally given by `ds.coordinates.axis_order`.
E.g. in cartesian they are 'x', 'y' and 'z'.
An arbitrary normal vector may be specified as a 3-element sequence of floats.
This function will return a :class:`OffAxisSlicePlot` object or a
:class:`AxisAlignedSlicePlot` object, depending on wether the requested
normal directions corresponds to a natural axis of the dataset's geometry.
fields : a (or a list of) 2-tuple of strings (ftype, fname)
The name of the field(s) to be plotted.
Any additional positional and keyword arguments are passed down to the appropriate
return class. See :class:`yt.visualization.plot_window.AxisAlignedProjectionPlot`
and :class:`yt.visualization.plot_window.OffAxisProjectionPlot`.
Raises
------
ValueError or TypeError
If `normal` cannot be interpreted as a valid normal direction.
"""
normal = _sanitize_normal_vector(ds, normal)
if isinstance(normal, str):
return AxisAlignedProjectionPlot(ds, normal, fields, *args, **kwargs)
else:
return OffAxisProjectionPlot(ds, normal, fields, *args, **kwargs)


def plot_2d(
Expand Down Expand Up @@ -2627,3 +2698,18 @@ def plot_2d(
aspect=aspect,
data_source=data_source,
)


# Historically, both names (SlicePlot and ProjectionPlot) were introduced
# as classes (now AxisAlignedSlicePlot and AxisAlignedProjectionPlot), and
# repurposed as factory functions to extend the existing API to utilize more
# general sibling classes (OffAxisSlicePlot and OffAxisProjectionPlot) in a
# user-friendly way.
# For ProjectionPlot, see https://github.com/yt-project/yt/pull/3450
#
# Here we define the snake case aliases for these factory functions as a way
# to improve their discoverability. We may want to make these the prefered
# API in the future, and, less probably, the *only* API, but there is no plan
# for deprecating the historical names as of PR #3450 (yt 4.1dev)
projection_plot = ProjectionPlot
slice_plot = SlicePlot

0 comments on commit ce567af

Please sign in to comment.