From 6a1097e1009f004214f9b2abbba9f3e97cd46d56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Mon, 21 Feb 2022 13:28:58 +0100 Subject: [PATCH 1/2] ENH: add support for vaying colors in quiver annotations --- doc/source/visualizing/callbacks.rst | 47 ++++++- yt/visualization/plot_modifications.py | 181 ++++++++++++++++--------- 2 files changed, 156 insertions(+), 72 deletions(-) diff --git a/doc/source/visualizing/callbacks.rst b/doc/source/visualizing/callbacks.rst index d1b405bbe3a..100bd2e2854 100644 --- a/doc/source/visualizing/callbacks.rst +++ b/doc/source/visualizing/callbacks.rst @@ -273,20 +273,23 @@ Overplot Quivers Axis-Aligned Data Sources ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: annotate_quiver(self, field_x, field_y, factor=16, scale=None, \ +.. function:: annotate_quiver(self, field_x, field_y, field_c=None, factor=16, scale=None, \ scale_units=None, normalize=False, plot_args=None) (This is a proxy for :class:`~yt.visualization.plot_modifications.QuiverCallback`.) Adds a 'quiver' plot to any plot, using the ``field_x`` and ``field_y`` from - the associated data, skipping every ``factor`` datapoints in the - discretization. ``scale`` is the data units per arrow length unit using + the associated data, skipping every ``factor`` pixels in the + discretization. A third field, ``field_c``, can be used as color; which is the + counterpart of ``matplotlib.axes.Axes.quiver``'s final positional argument ``C``. + ``scale`` is the data units per arrow length unit using ``scale_units``. If ``normalize`` is ``True``, the fields will be scaled by their local (in-plane) length, allowing morphological features to be more clearly seen for fields with substantial variation in field strength. - Additional arguments can be passed to the ``plot_args`` dictionary, see - matplotlib.axes.Axes.quiver for more info. + All additional keyword arguments are passed down to ``matplotlib.Axes.axes.quiver``. + + Example using a constant color .. python-script:: @@ -301,10 +304,40 @@ Axis-Aligned Data Sources weight_field="density", width=(20, "kpc"), ) - p.annotate_quiver(("gas", "velocity_x"), ("gas", "velocity_y"), factor=16, - plot_args={"color": "purple"}) + p.annotate_quiver( + ("gas", "velocity_x"), + ("gas", "velocity_y"), + factor=16, + color="purple", + ) p.save() + + And now using a continuous colormap + +.. python-script:: + + import yt + + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + p = yt.ProjectionPlot( + ds, + "z", + ("gas", "density"), + center=[0.5, 0.5, 0.5], + weight_field="density", + width=(20, "kpc"), + ) + p.annotate_quiver( + ("gas", "velocity_x"), + ("gas", "velocity_y"), + ("gas", "vorticity_z"), + factor=16, + cmap="inferno_r", + ) + p.save() + + Off-Axis Data Sources ^^^^^^^^^^^^^^^^^^^^^ diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py index 212a3ab558f..573c156eb5c 100644 --- a/yt/visualization/plot_modifications.py +++ b/yt/visualization/plot_modifications.py @@ -1,5 +1,7 @@ +import inspect import re import warnings +from abc import ABC, abstractmethod from functools import wraps from numbers import Integral, Number from typing import Any, Dict, Optional, Tuple, Union @@ -78,6 +80,8 @@ class PlotCallback: def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) + if inspect.isabstract(cls): + return callback_registry[cls.__name__] = cls cls.__call__ = _verify_geometry(cls.__call__) @@ -357,7 +361,7 @@ def __call__(self, plot): qcb = CuttingQuiverCallback( (ftype, "cutting_plane_velocity_x"), (ftype, "cutting_plane_velocity_y"), - self.factor, + factor=self.factor, scale=self.scale, normalize=self.normalize, scale_units=self.scale_units, @@ -396,7 +400,7 @@ def __call__(self, plot): qcb = QuiverCallback( xv, yv, - self.factor, + factor=self.factor, scale=self.scale, scale_units=self.scale_units, normalize=self.normalize, @@ -451,7 +455,7 @@ def __call__(self, plot): qcb = CuttingQuiverCallback( (ftype, "cutting_plane_magnetic_field_x"), (ftype, "cutting_plane_magnetic_field_y"), - self.factor, + factor=self.factor, scale=self.scale, scale_units=self.scale_units, normalize=self.normalize, @@ -479,7 +483,7 @@ def __call__(self, plot): qcb = QuiverCallback( xv, yv, - self.factor, + factor=self.factor, scale=self.scale, scale_units=self.scale_units, normalize=self.normalize, @@ -488,10 +492,64 @@ def __call__(self, plot): return qcb(plot) -class QuiverCallback(PlotCallback): +class BaseQuiverCallback(PlotCallback, ABC): + def __init__( + self, + field_x, + field_y, + field_c=None, + *, + factor: Union[Tuple[int, int], int] = 16, + scale=None, + scale_units=None, + normalize=False, + plot_args=None, + **kwargs, + ): + PlotCallback.__init__(self) + self.field_x = field_x + self.field_y = field_y + self.field_c = field_c + self.factor = _validate_factor_tuple(factor) + self.scale = scale + self.scale_units = scale_units + self.normalize = normalize + if plot_args is None: + plot_args = kwargs + else: + # using plot_args should be deprecated at some point, + # but it needs to be done consistently for all callbacks + plot_args.update(kwargs) + + self.plot_args = plot_args + + @abstractmethod + def __call__(self, plot): + pass + + def _finalize(self, plot, X, Y, pixX, pixY, pixC): + if self.normalize: + nn = np.sqrt(pixX**2 + pixY**2) + pixX /= nn + pixY /= nn + + args = [X, Y, pixX, pixY] + if pixC is not None: + args.append(pixC) + + kwargs = dict( + scale=self.scale, + scale_units=self.scale_units, + ) + kwargs.update(self.plot_args) + return plot._axes.quiver(*args, **kwargs) + + +class QuiverCallback(BaseQuiverCallback): """ Adds a 'quiver' plot to any plot, using the *field_x* and *field_y* - from the associated data, skipping every *factor* datapoints. + from the associated data, skipping every *factor* pixels. + *field_c* is an optional field name used for color. *scale* is the data units per arrow length unit using *scale_units* and *plot_args* allows you to pass in matplotlib arguments (see matplotlib.axes.Axes.quiver for more info). if *normalize* is True, @@ -507,6 +565,8 @@ def __init__( self, field_x, field_y, + field_c=None, + *, factor: Union[Tuple[int, int], int] = 16, scale=None, scale_units=None, @@ -514,19 +574,21 @@ def __init__( bv_x=0, bv_y=0, plot_args=None, + **kwargs, ): - PlotCallback.__init__(self) - self.field_x = field_x - self.field_y = field_y + super().__init__( + field_x, + field_y, + field_c, + factor=factor, + scale=scale, + scale_units=scale_units, + normalize=normalize, + plot_args=plot_args, + **kwargs, + ) self.bv_x = bv_x self.bv_y = bv_y - self.factor = _validate_factor_tuple(factor) - self.scale = scale - self.scale_units = scale_units - self.normalize = normalize - if plot_args is None: - plot_args = {} - self.plot_args = plot_args def __call__(self, plot): x0, x1, y0, y1 = self._physical_bounds(plot) @@ -579,25 +641,27 @@ def _transformed_field(field, data): False, # antialias periodic, ) + if self.field_c is not None: + pixC = plot.data.ds.coordinates.pixelize( + plot.data.axis, + plot.data, + self.field_c, + bounds, + (nx, ny), + False, # antialias + periodic, + ) + else: + pixC = None + X, Y = np.meshgrid( np.linspace(xx0, xx1, nx, endpoint=True), np.linspace(yy0, yy1, ny, endpoint=True), ) - if self.normalize: - nn = np.sqrt(pixX**2 + pixY**2) - pixX /= nn - pixY /= nn - plot._axes.quiver( - X, - Y, - pixX, - pixY, - scale=self.scale, - scale_units=self.scale_units, - **self.plot_args, - ) + retv = self._finalize(plot, X, Y, pixX, pixY, pixC) plot._axes.set_xlim(xx0, xx1) plot._axes.set_ylim(yy0, yy1) + return retv class ContourCallback(PlotCallback): @@ -1131,7 +1195,7 @@ def __call__(self, plot): super().__call__(plot) -class CuttingQuiverCallback(PlotCallback): +class CuttingQuiverCallback(BaseQuiverCallback): """ Get a quiver plot on top of a cutting plane, using *field_x* and *field_y*, skipping every *factor* datapoint in the discretization. @@ -1146,27 +1210,6 @@ class CuttingQuiverCallback(PlotCallback): _type_name = "cquiver" _supported_geometries = ("cartesian", "spectral_cube") - def __init__( - self, - field_x, - field_y, - factor=16, - scale=None, - scale_units=None, - normalize=False, - plot_args=None, - ): - PlotCallback.__init__(self) - self.field_x = field_x - self.field_y = field_y - self.factor = _validate_factor_tuple(factor) - self.scale = scale - self.scale_units = scale_units - self.normalize = normalize - if plot_args is None: - plot_args = {} - self.plot_args = plot_args - def __call__(self, plot): x0, x1, y0, y1 = self._physical_bounds(plot) xx0, xx1, yy0, yy1 = self._plot_bounds(plot) @@ -1208,27 +1251,35 @@ def __call__(self, plot): plot.data[self.field_y], (x0, x1, y0, y1), ) + if self.field_c is not None: + pixC = np.zeros((ny, nx), dtype="f8") + pixelize_off_axis_cartesian( + pixC, + plot.data[("index", "x")].to("code_length"), + plot.data[("index", "y")].to("code_length"), + plot.data[("index", "z")].to("code_length"), + plot.data["px"], + plot.data["py"], + plot.data["pdx"], + plot.data["pdy"], + plot.data["pdz"], + plot.data.center, + plot.data._inv_mat, + indices, + plot.data[self.field_c], + (x0, x1, y0, y1), + ) + else: + pixC = None X, Y = np.meshgrid( np.linspace(xx0, xx1, nx, endpoint=True), np.linspace(yy0, yy1, ny, endpoint=True), ) - if self.normalize: - nn = np.sqrt(pixX**2 + pixY**2) - pixX /= nn - pixY /= nn - - plot._axes.quiver( - X, - Y, - pixX, - pixY, - scale=self.scale, - scale_units=self.scale_units, - **self.plot_args, - ) + retv = self._finalize(plot, X, Y, pixX, pixY, pixC) plot._axes.set_xlim(xx0, xx1) plot._axes.set_ylim(yy0, yy1) + return retv class ClumpContourCallback(PlotCallback): From e08826f200bf48598bfe13d1488cff131923138e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Mon, 21 Feb 2022 15:11:57 +0100 Subject: [PATCH 2/2] BUG: fix a bug where normalizing a quiverplot with 0-len vectors would yield runtims warnings from numpy (divide by 0 errors) --- yt/visualization/plot_modifications.py | 1 + 1 file changed, 1 insertion(+) diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py index 573c156eb5c..f3bef38b979 100644 --- a/yt/visualization/plot_modifications.py +++ b/yt/visualization/plot_modifications.py @@ -530,6 +530,7 @@ def __call__(self, plot): def _finalize(self, plot, X, Y, pixX, pixY, pixC): if self.normalize: nn = np.sqrt(pixX**2 + pixY**2) + nn = np.where(nn == 0, 1, nn) pixX /= nn pixY /= nn