Skip to content

Commit

Permalink
ENH: add support for passing fields keys for colors and linewidth in …
Browse files Browse the repository at this point in the history
…streamline plot annotations
  • Loading branch information
neutrinoceros committed May 28, 2023
1 parent 1814957 commit 2424f9c
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 59 deletions.
8 changes: 5 additions & 3 deletions doc/source/visualizing/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,8 @@ Overplot a Circle on a Plot
Overplot Streamlines
~~~~~~~~~~~~~~~~~~~~

.. function:: annotate_streamlines(self, field_x, field_y, *, factor=16, \
density=1, display_threshold=None, \
**kwargs)
.. function:: annotate_streamlines(self, field_x, field_y, *, linewidth=1.0, linewidth_upscaling=1.0, \
color=None, color_threshold=float('-inf'), factor=16, **kwargs)
(This is a proxy for
:class:`~yt.visualization.plot_modifications.StreamlineCallback`.)
Expand All @@ -591,6 +590,9 @@ Overplot Streamlines
``start_at_yedge``. A line with the qmean vector magnitude will cover
1.0/``factor`` of the image.

Additional keyword arguments are passed down to
`matplotlib.axes.Axes.streamplot <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.streamplot.html>`_

.. python-script::

import yt
Expand Down
204 changes: 151 additions & 53 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

import matplotlib
import numpy as np
from unyt import unyt_quantity

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import AnyFieldKey
from yt._typing import AnyFieldKey, FieldKey
from yt.data_objects.data_containers import YTDataContainer
from yt.data_objects.level_sets.clump_handling import Clump
from yt.data_objects.selection_objects.cut_region import YTCutRegion
Expand All @@ -21,7 +22,12 @@
from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
from yt.units import dimensions
from yt.units.yt_array import YTArray, YTQuantity, uhstack # type: ignore
from yt.utilities.exceptions import YTDataTypeUnsupported, YTUnsupportedPlotCallback
from yt.utilities.exceptions import (
YTDataTypeUnsupported,
YTFieldNotFound,
YTFieldTypeNotFound,
YTUnsupportedPlotCallback,
)
from yt.utilities.lib.geometry_utils import triangle_plane_intersect
from yt.utilities.lib.line_integral_convolution import line_integral_convolution_2d
from yt.utilities.lib.mesh_triangulation import triangulate_indices
Expand All @@ -39,6 +45,11 @@
from yt.visualization.image_writer import apply_colormap
from yt.visualization.plot_window import PWViewerMPL

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

if sys.version_info >= (3, 11):
from typing import assert_never
else:
Expand Down Expand Up @@ -1201,15 +1212,52 @@ def __call__(self, plot):
plot._axes.text(xi, yi, "%d" % block_ids[n], clip_on=True)


# when type-checking with MPL >= 3.8, use
# from matplotlib.typing import ColorType
_ColorType = Any


class StreamlineCallback(PlotCallback):
"""
Add streamlines to any plot, using the *field_x* and *field_y*
from the associated data, skipping every *factor* datapoints like
'quiver'. *density* is the index of the amount of the streamlines.
*field_color* is a field to be used to colormap the streamlines.
If *display_threshold* is supplied, any streamline segments where
*field_color* is less than the threshold will be removed by having
their line width set to 0.
Plot streamlines using matplotlib.axes.Axes.streamplot
Arguments
---------
field_x: field key
The "velocity" analoguous field along the horizontal direction.
field_y: field key
The "velocity" analoguous field along the vertical direction.
linewidth: float, or field key (default: 1.0)
A constant scalar will be passed directly to matplotlib.axes.Axes.streamplot
A field key will be first interpreted by yt and produce the adequate 2D array.
Data fields are normalized by their maximum value, so the maximal linewidth
is 1 by default. See `linewidth_upscaling` for fine tuning.
Note that the absolute value is taken in all cases.
linewidth_upscaling: float (default: 1.0)
A constant multiplicative factor applied to linewidth.
Final linewidth is obtained as:
linewidth_upscaling * abs(linewidth) / max(abs(linewidth))
color: a color identifier, or a field key (default: matplotlib.rcParams['line.color'])
A constant color identifier will be passed directly to matplotlib.axes.Axes.streamplot
A field key will be first interpreted by yt and produce the adequate 2D array.
See https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.streamplot.html
for how to customize color mapping using `cmap` and `norm` arguments.
color_threshold: float or unyt_quantity (default: -inf)
Regions where the field used for color is lower than this threshold will be masked.
Only used if color is a field key.
factor: int, or tuple[int, int] (default: 16)
Fields are downed-sampled by this factor with respect to the background image
buffer size. A single integer factor will be used for both direction, but a tuple
of 2 integers can be passed to set x and y downsampling independently.
**kwargs: any additional keyword arguments will be passed
directly to matplotlib.axes.Axes.streamplot
"""

_type_name = "streamlines"
Expand All @@ -1224,76 +1272,128 @@ class StreamlineCallback(PlotCallback):

def __init__(
self,
field_x,
field_y,
field_x: AnyFieldKey,
field_y: AnyFieldKey,
*,
linewidth: Union[float, AnyFieldKey] = 1.0,
linewidth_upscaling: float = 1.0,
color: Optional[Union[_ColorType, FieldKey]] = None,
color_threshold: Union[float, unyt_quantity] = float("-inf"),
factor: Union[Tuple[int, int], int] = 16,
density=1,
field_color=None,
display_threshold=None,
plot_args=None,
field_color=None, # deprecated
display_threshold=None, # deprecated
plot_args=None, # deprecated
**kwargs,
):
self.field_x = field_x
self.field_y = field_y
self.field_color = field_color
if color is not None and field_color is not None:
raise TypeError(
"`color` and `field_color` keyword arguments "
"cannot be set at the same time."
)
elif field_color is not None:
issue_deprecation_warning(
"The `field_color` keyword argument is deprecated. "
"Use `color` instead.",
since="4.3",
stacklevel=5,
)
self._color = field_color
else:
self._color = color

if color_threshold is not None and display_threshold is not None:
raise TypeError(
"`color_threshold` and `display_threshold` keyword arguments "
"cannot be set at the same time."
)
elif display_threshold is not None:
issue_deprecation_warning(
"The `display_threshold` keyword argument is deprecated. "
"Use `color_threshold` instead.",
since="4.3",
stacklevel=5,
)
self._color_threshold = display_threshold
else:
self._color_threshold = color_threshold

self._linewidth = linewidth
self._linewidth_upscaling = linewidth_upscaling
self.factor = _validate_factor_tuple(factor)
self.dens = density
self.display_threshold = display_threshold

if plot_args is not None:
issue_deprecation_warning(
"`plot_args` is deprecated. "
"You can now pass arbitrary keyword arguments instead of a dictionary.",
since="4.1.0",
stacklevel=5,
)
plot_args.update(kwargs)
else:
plot_args = kwargs

self.plot_args = plot_args

def __call__(self, plot):
bounds = self._physical_bounds(plot)
def __call__(self, plot) -> None:
xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

# We are feeding this size into the pixelizer, where it will properly
# set it in reverse order
nx = plot.raw_image_shape[1] // self.factor[0]
ny = plot.raw_image_shape[0] // self.factor[1]
pixX = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_x, bounds, (nx, ny)
)
pixY = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_y, bounds, (nx, ny)
)
if self.field_color:
field_colors = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_color, bounds, (nx, ny)

def pixelize(field):
retv = plot.data.ds.coordinates.pixelize(
plot.data.axis,
plot.data,
field=field,
bounds=self._physical_bounds(plot),
size=(nx, ny),
)
if plot._swap_axes:
return retv.transpose()
else:
return retv

if self.display_threshold:
mask = field_colors > self.display_threshold
lwdefault = matplotlib.rcParams["lines.linewidth"]
def is_field_key(val) -> TypeGuard[AnyFieldKey]:
if not is_sequence(val):
return False
try:
plot.data._determine_fields(val)
except (YTFieldNotFound, YTFieldTypeNotFound):
return False
else:
return True

if "linewidth" in self.plot_args:
linewidth = self.plot_args["linewidth"]
else:
linewidth = lwdefault

try:
linewidth *= mask
self.plot_args["linewidth"] = linewidth
except ValueError as e:
err_msg = (
"Error applying display threshold: linewidth"
+ "must have shape ({}, {}) or be scalar"
)
err_msg = err_msg.format(nx, ny)
raise ValueError(err_msg) from e
pixX = pixelize(self.field_x)
pixY = pixelize(self.field_y)

if isinstance(self._linewidth, (int, float)):
linewidth = self._linewidth_upscaling * self._linewidth
elif is_field_key(self._linewidth):
linewidth = pixelize(self._linewidth)
linewidth *= self._linewidth_upscaling / np.abs(linewidth).max()
else:
field_colors = None
raise TypeError(
f"annotate_streamlines received linewidth={self._linewidth!r}, "
f"with type {type(self._linewidth)}. Expected a float or a field key."
)

if is_field_key(self._color):
color = pixelize(self._color)
linewidth *= color > self._color_threshold
else:
if (_cmap := self.plot_args.get("cmap")) is not None:
warnings.warn(
f"annotate_streamlines received color={self._color!r}, "
"which wasn't recognized as as field key. "
"It is assumed to be a fixed color identifier. "
f"Also received cmap={_cmap!r}, which will be ignored.",
stacklevel=5,
)
color = self._color

X, Y = (
np.linspace(xx0, xx1, nx, endpoint=True),
Expand All @@ -1302,19 +1402,17 @@ def __call__(self, plot):
X, Y, pixX, pixY = self._sanitize_xy_order(plot, X, Y, pixX, pixY)
if plot._swap_axes:
# need an additional transpose here for streamline tracing
pixX = pixX.transpose()
pixY = pixY.transpose()
X = X.transpose()
Y = Y.transpose()
streamplot_args = {
"x": X,
"y": Y,
"u": pixX,
"v": pixY,
"density": self.dens,
"color": field_colors,
"color": color,
"linewidth": linewidth,
**self.plot_args,
}
streamplot_args.update(self.plot_args)
plot._axes.streamplot(**streamplot_args)
self._set_plot_limits(plot, (xx0, xx1, yy0, yy1))

Expand Down
9 changes: 6 additions & 3 deletions yt/visualization/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,17 +914,20 @@ def test_streamline_callback():
p.annotate_streamlines(
("gas", "velocity_x"),
("gas", "velocity_y"),
field_color=("stream", "magvel"),
color=("stream", "magvel"),
)
assert_fname(p.save(prefix)[0])
check_axis_manipulation(p, prefix)

# a more thorough example involving many keyword arguments
p = SlicePlot(ds, ax, ("gas", "density"))
p.annotate_streamlines(
("gas", "velocity_x"),
("gas", "velocity_y"),
field_color=("stream", "magvel"),
display_threshold=0.5,
linewidth=("gas", "density"),
linewidth_upscaling=3,
color=("stream", "magvel"),
color_threshold=0.5,
cmap="viridis",
arrowstyle="->",
)
Expand Down

0 comments on commit 2424f9c

Please sign in to comment.