diff --git a/doc/source/visualizing/callbacks.rst b/doc/source/visualizing/callbacks.rst index 60dd11eedeb..1f7894d7c11 100644 --- a/doc/source/visualizing/callbacks.rst +++ b/doc/source/visualizing/callbacks.rst @@ -577,9 +577,8 @@ Overplot a Circle on a Plot Overplot Streamlines ~~~~~~~~~~~~~~~~~~~~ -.. function:: annotate_streamlines(self, field_x, field_y, *, factor=16, \ - density=1, display_threshold=None, \ - **kwargs) +.. function:: annotate_streamlines(self, field_x, field_y, *, linewidth=1.0, linewidth_upscaling=1.0, \ + color=None, color_threshold=float('-inf'), factor=16, **kwargs) (This is a proxy for :class:`~yt.visualization.plot_modifications.StreamlineCallback`.) @@ -591,6 +590,9 @@ Overplot Streamlines ``start_at_yedge``. A line with the qmean vector magnitude will cover 1.0/``factor`` of the image. + Additional keyword arguments are passed down to + `matplotlib.axes.Axes.streamplot `_ + .. python-script:: import yt diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py index edc78f0d072..47d1a73545c 100644 --- a/yt/visualization/plot_modifications.py +++ b/yt/visualization/plot_modifications.py @@ -9,9 +9,10 @@ import matplotlib import numpy as np +from unyt import unyt_quantity from yt._maintenance.deprecation import issue_deprecation_warning -from yt._typing import AnyFieldKey +from yt._typing import AnyFieldKey, FieldKey from yt.data_objects.data_containers import YTDataContainer from yt.data_objects.level_sets.clump_handling import Clump from yt.data_objects.selection_objects.cut_region import YTCutRegion @@ -21,7 +22,12 @@ from yt.geometry.unstructured_mesh_handler import UnstructuredIndex from yt.units import dimensions from yt.units.yt_array import YTArray, YTQuantity, uhstack # type: ignore -from yt.utilities.exceptions import YTDataTypeUnsupported, YTUnsupportedPlotCallback +from yt.utilities.exceptions import ( + YTDataTypeUnsupported, + YTFieldNotFound, + YTFieldTypeNotFound, + YTUnsupportedPlotCallback, +) from yt.utilities.lib.geometry_utils import triangle_plane_intersect from yt.utilities.lib.line_integral_convolution import line_integral_convolution_2d from yt.utilities.lib.mesh_triangulation import triangulate_indices @@ -39,6 +45,11 @@ from yt.visualization.image_writer import apply_colormap from yt.visualization.plot_window import PWViewerMPL +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + if sys.version_info >= (3, 11): from typing import assert_never else: @@ -1206,15 +1217,52 @@ def __call__(self, plot): plot._axes.text(xi, yi, "%d" % block_ids[n], clip_on=True) +# when type-checking with MPL >= 3.8, use +# from matplotlib.typing import ColorType +_ColorType = Any + + class StreamlineCallback(PlotCallback): """ - Add streamlines to any plot, using the *field_x* and *field_y* - from the associated data, skipping every *factor* datapoints like - 'quiver'. *density* is the index of the amount of the streamlines. - *field_color* is a field to be used to colormap the streamlines. - If *display_threshold* is supplied, any streamline segments where - *field_color* is less than the threshold will be removed by having - their line width set to 0. + Plot streamlines using matplotlib.axes.Axes.streamplot + + Arguments + --------- + + field_x: field key + The "velocity" analoguous field along the horizontal direction. + field_y: field key + The "velocity" analoguous field along the vertical direction. + + linewidth: float, or field key (default: 1.0) + A constant scalar will be passed directly to matplotlib.axes.Axes.streamplot + A field key will be first interpreted by yt and produce the adequate 2D array. + Data fields are normalized by their maximum value, so the maximal linewidth + is 1 by default. See `linewidth_upscaling` for fine tuning. + Note that the absolute value is taken in all cases. + + linewidth_upscaling: float (default: 1.0) + A constant multiplicative factor applied to linewidth. + Final linewidth is obtained as: + linewidth_upscaling * abs(linewidth) / max(abs(linewidth)) + + color: a color identifier, or a field key (default: matplotlib.rcParams['line.color']) + A constant color identifier will be passed directly to matplotlib.axes.Axes.streamplot + A field key will be first interpreted by yt and produce the adequate 2D array. + See https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.streamplot.html + for how to customize color mapping using `cmap` and `norm` arguments. + + color_threshold: float or unyt_quantity (default: -inf) + Regions where the field used for color is lower than this threshold will be masked. + Only used if color is a field key. + + factor: int, or tuple[int, int] (default: 16) + Fields are downed-sampled by this factor with respect to the background image + buffer size. A single integer factor will be used for both direction, but a tuple + of 2 integers can be passed to set x and y downsampling independently. + + **kwargs: any additional keyword arguments will be passed + directly to matplotlib.axes.Axes.streamplot """ _type_name = "streamlines" @@ -1229,22 +1277,56 @@ class StreamlineCallback(PlotCallback): def __init__( self, - field_x, - field_y, + field_x: AnyFieldKey, + field_y: AnyFieldKey, *, + linewidth: Union[float, AnyFieldKey] = 1.0, + linewidth_upscaling: float = 1.0, + color: Optional[Union[_ColorType, FieldKey]] = None, + color_threshold: Union[float, unyt_quantity] = float("-inf"), factor: Union[Tuple[int, int], int] = 16, - density=1, - field_color=None, - display_threshold=None, - plot_args=None, + field_color=None, # deprecated + display_threshold=None, # deprecated + plot_args=None, # deprecated **kwargs, ): self.field_x = field_x self.field_y = field_y - self.field_color = field_color + if color is not None and field_color is not None: + raise TypeError( + "`color` and `field_color` keyword arguments " + "cannot be set at the same time." + ) + elif field_color is not None: + issue_deprecation_warning( + "The `field_color` keyword argument is deprecated. " + "Use `color` instead.", + since="4.3", + stacklevel=5, + ) + self._color = field_color + else: + self._color = color + + if color_threshold is not None and display_threshold is not None: + raise TypeError( + "`color_threshold` and `display_threshold` keyword arguments " + "cannot be set at the same time." + ) + elif display_threshold is not None: + issue_deprecation_warning( + "The `display_threshold` keyword argument is deprecated. " + "Use `color_threshold` instead.", + since="4.3", + stacklevel=5, + ) + self._color_threshold = display_threshold + else: + self._color_threshold = color_threshold + + self._linewidth = linewidth + self._linewidth_upscaling = linewidth_upscaling self.factor = _validate_factor_tuple(factor) - self.dens = density - self.display_threshold = display_threshold if plot_args is not None: issue_deprecation_warning( @@ -1259,47 +1341,64 @@ def __init__( self.plot_args = plot_args - def __call__(self, plot): - bounds = self._physical_bounds(plot) + def __call__(self, plot) -> None: xx0, xx1, yy0, yy1 = self._plot_bounds(plot) # We are feeding this size into the pixelizer, where it will properly # set it in reverse order nx = plot.raw_image_shape[1] // self.factor[0] ny = plot.raw_image_shape[0] // self.factor[1] - pixX = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_x, bounds, (nx, ny) - ) - pixY = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_y, bounds, (nx, ny) - ) - if self.field_color: - field_colors = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_color, bounds, (nx, ny) + + def pixelize(field): + retv = plot.data.ds.coordinates.pixelize( + plot.data.axis, + plot.data, + field=field, + bounds=self._physical_bounds(plot), + size=(nx, ny), ) + if plot._swap_axes: + return retv.transpose() + else: + return retv - if self.display_threshold: - mask = field_colors > self.display_threshold - lwdefault = matplotlib.rcParams["lines.linewidth"] + def is_field_key(val) -> TypeGuard[AnyFieldKey]: + if not is_sequence(val): + return False + try: + plot.data._determine_fields(val) + except (YTFieldNotFound, YTFieldTypeNotFound): + return False + else: + return True - if "linewidth" in self.plot_args: - linewidth = self.plot_args["linewidth"] - else: - linewidth = lwdefault - - try: - linewidth *= mask - self.plot_args["linewidth"] = linewidth - except ValueError as e: - err_msg = ( - "Error applying display threshold: linewidth" - + "must have shape ({}, {}) or be scalar" - ) - err_msg = err_msg.format(nx, ny) - raise ValueError(err_msg) from e + pixX = pixelize(self.field_x) + pixY = pixelize(self.field_y) + if isinstance(self._linewidth, (int, float)): + linewidth = self._linewidth_upscaling * self._linewidth + elif is_field_key(self._linewidth): + linewidth = pixelize(self._linewidth) + linewidth *= self._linewidth_upscaling / np.abs(linewidth).max() else: - field_colors = None + raise TypeError( + f"annotate_streamlines received linewidth={self._linewidth!r}, " + f"with type {type(self._linewidth)}. Expected a float or a field key." + ) + + if is_field_key(self._color): + color = pixelize(self._color) + linewidth *= color > self._color_threshold + else: + if (_cmap := self.plot_args.get("cmap")) is not None: + warnings.warn( + f"annotate_streamlines received color={self._color!r}, " + "which wasn't recognized as as field key. " + "It is assumed to be a fixed color identifier. " + f"Also received cmap={_cmap!r}, which will be ignored.", + stacklevel=5, + ) + color = self._color X, Y = ( np.linspace(xx0, xx1, nx, endpoint=True), @@ -1308,8 +1407,6 @@ def __call__(self, plot): X, Y, pixX, pixY = self._sanitize_xy_order(plot, X, Y, pixX, pixY) if plot._swap_axes: # need an additional transpose here for streamline tracing - pixX = pixX.transpose() - pixY = pixY.transpose() X = X.transpose() Y = Y.transpose() streamplot_args = { @@ -1317,10 +1414,10 @@ def __call__(self, plot): "y": Y, "u": pixX, "v": pixY, - "density": self.dens, - "color": field_colors, + "color": color, + "linewidth": linewidth, + **self.plot_args, } - streamplot_args.update(self.plot_args) plot._axes.streamplot(**streamplot_args) self._set_plot_limits(plot, (xx0, xx1, yy0, yy1)) diff --git a/yt/visualization/tests/test_callbacks.py b/yt/visualization/tests/test_callbacks.py index d3eb20b152c..dea272baf4d 100644 --- a/yt/visualization/tests/test_callbacks.py +++ b/yt/visualization/tests/test_callbacks.py @@ -913,17 +913,20 @@ def test_streamline_callback(): p.annotate_streamlines( ("gas", "velocity_x"), ("gas", "velocity_y"), - field_color=("stream", "magvel"), + color=("stream", "magvel"), ) assert_fname(p.save(prefix)[0]) check_axis_manipulation(p, prefix) + # a more thorough example involving many keyword arguments p = SlicePlot(ds, ax, ("gas", "density")) p.annotate_streamlines( ("gas", "velocity_x"), ("gas", "velocity_y"), - field_color=("stream", "magvel"), - display_threshold=0.5, + linewidth=("gas", "density"), + linewidth_upscaling=3, + color=("stream", "magvel"), + color_threshold=0.5, cmap="viridis", arrowstyle="->", )