Skip to content

Commit

Permalink
ENH: add support for colored quiver callback
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Feb 21, 2022
1 parent 395874e commit bdf36b7
Showing 1 changed file with 116 additions and 62 deletions.
178 changes: 116 additions & 62 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import re
import warnings
from abc import ABC, abstractmethod
from functools import wraps
from numbers import Integral, Number
from typing import Any, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -78,6 +80,8 @@ class PlotCallback:

def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
if inspect.isabstract(cls):
return
callback_registry[cls.__name__] = cls
cls.__call__ = _verify_geometry(cls.__call__)

Expand Down Expand Up @@ -451,7 +455,7 @@ def __call__(self, plot):
qcb = CuttingQuiverCallback(
(ftype, "cutting_plane_magnetic_field_x"),
(ftype, "cutting_plane_magnetic_field_y"),
self.factor,
factor=self.factor,
scale=self.scale,
scale_units=self.scale_units,
normalize=self.normalize,
Expand Down Expand Up @@ -479,7 +483,7 @@ def __call__(self, plot):
qcb = QuiverCallback(
xv,
yv,
self.factor,
factor=self.factor,
scale=self.scale,
scale_units=self.scale_units,
normalize=self.normalize,
Expand All @@ -488,10 +492,67 @@ def __call__(self, plot):
return qcb(plot)


class QuiverCallback(PlotCallback):
class BaseQuiverCallback(PlotCallback, ABC):
def __init__(
self,
field_x,
field_y,
field_c=None,
*,
factor: Union[Tuple[int, int], int] = 16,
scale=None,
scale_units=None,
normalize=False,
plot_args=None,
**kwargs,
):
PlotCallback.__init__(self)
self.field_x = field_x
self.field_y = field_y
self.field_c = field_c
self.factor = _validate_factor_tuple(factor)
self.scale = scale
self.scale_units = scale_units
self.normalize = normalize
if plot_args is None:
plot_args = kwargs
else:
issue_deprecation_warning(
"Passing a dictionary 'plot_args' argument to annotate_quiver methods is deprecated. "
"Instead, you can use arbitrary keword arguments to be passed to matplotlib.Axes.axes.quiver",
since="4.1.0",
)
plot_args.update(kwargs)

self.plot_args = plot_args

@abstractmethod
def __call__(self, plot):
pass

def _finalize(self, plot, X, Y, pixX, pixY, pixC):
if self.normalize:
nn = np.sqrt(pixX**2 + pixY**2)
pixX /= nn
pixY /= nn

args = [X, Y, pixX, pixY]
if pixC is not None:
args.append(pixC)

kwargs = dict(
scale=self.scale,
scale_units=self.scale_units,
)
kwargs.update(self.plot_args)
return plot._axes.quiver(*args, **kwargs)


class QuiverCallback(BaseQuiverCallback):
"""
Adds a 'quiver' plot to any plot, using the *field_x* and *field_y*
from the associated data, skipping every *factor* datapoints.
*field_c* is an optional field name used for color.
*scale* is the data units per arrow length unit using *scale_units*
and *plot_args* allows you to pass in matplotlib arguments (see
matplotlib.axes.Axes.quiver for more info). if *normalize* is True,
Expand All @@ -507,26 +568,30 @@ def __init__(
self,
field_x,
field_y,
field_c=None,
*,
factor: Union[Tuple[int, int], int] = 16,
scale=None,
scale_units=None,
normalize=False,
bv_x=0,
bv_y=0,
plot_args=None,
**kwargs,
):
PlotCallback.__init__(self)
self.field_x = field_x
self.field_y = field_y
super().__init__(
field_x,
field_y,
field_c,
factor=factor,
scale=scale,
scale_units=scale_units,
normalize=normalize,
plot_args=plot_args,
**kwargs,
)
self.bv_x = bv_x
self.bv_y = bv_y
self.factor = _validate_factor_tuple(factor)
self.scale = scale
self.scale_units = scale_units
self.normalize = normalize
if plot_args is None:
plot_args = {}
self.plot_args = plot_args

def __call__(self, plot):
x0, x1, y0, y1 = self._physical_bounds(plot)
Expand Down Expand Up @@ -579,25 +644,27 @@ def _transformed_field(field, data):
False, # antialias
periodic,
)
if self.field_c is not None:
pixC = plot.data.ds.coordinates.pixelize(
plot.data.axis,
plot.data,
self.field_c,
bounds,
(nx, ny),
False, # antialias
periodic,
)
else:
pixC = None

X, Y = np.meshgrid(
np.linspace(xx0, xx1, nx, endpoint=True),
np.linspace(yy0, yy1, ny, endpoint=True),
)
if self.normalize:
nn = np.sqrt(pixX**2 + pixY**2)
pixX /= nn
pixY /= nn
plot._axes.quiver(
X,
Y,
pixX,
pixY,
scale=self.scale,
scale_units=self.scale_units,
**self.plot_args,
)
retv = self._finalize(plot, X, Y, pixX, pixY, pixC)
plot._axes.set_xlim(xx0, xx1)
plot._axes.set_ylim(yy0, yy1)
return retv


class ContourCallback(PlotCallback):
Expand Down Expand Up @@ -1131,7 +1198,7 @@ def __call__(self, plot):
super().__call__(plot)


class CuttingQuiverCallback(PlotCallback):
class CuttingQuiverCallback(BaseQuiverCallback):
"""
Get a quiver plot on top of a cutting plane, using *field_x* and
*field_y*, skipping every *factor* datapoint in the discretization.
Expand All @@ -1146,27 +1213,6 @@ class CuttingQuiverCallback(PlotCallback):
_type_name = "cquiver"
_supported_geometries = ("cartesian", "spectral_cube")

def __init__(
self,
field_x,
field_y,
factor=16,
scale=None,
scale_units=None,
normalize=False,
plot_args=None,
):
PlotCallback.__init__(self)
self.field_x = field_x
self.field_y = field_y
self.factor = _validate_factor_tuple(factor)
self.scale = scale
self.scale_units = scale_units
self.normalize = normalize
if plot_args is None:
plot_args = {}
self.plot_args = plot_args

def __call__(self, plot):
x0, x1, y0, y1 = self._physical_bounds(plot)
xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
Expand Down Expand Up @@ -1208,27 +1254,35 @@ def __call__(self, plot):
plot.data[self.field_y],
(x0, x1, y0, y1),
)
if self.field_c is not None:
pixC = np.zeros((ny, nx), dtype="f8")
pixelize_off_axis_cartesian(
pixC,
plot.data[("index", "x")].to("code_length"),
plot.data[("index", "y")].to("code_length"),
plot.data[("index", "z")].to("code_length"),
plot.data["px"],
plot.data["py"],
plot.data["pdx"],
plot.data["pdy"],
plot.data["pdz"],
plot.data.center,
plot.data._inv_mat,
indices,
plot.data[self.field_c],
(x0, x1, y0, y1),
)
else:
pixC = None
X, Y = np.meshgrid(
np.linspace(xx0, xx1, nx, endpoint=True),
np.linspace(yy0, yy1, ny, endpoint=True),
)

if self.normalize:
nn = np.sqrt(pixX**2 + pixY**2)
pixX /= nn
pixY /= nn

plot._axes.quiver(
X,
Y,
pixX,
pixY,
scale=self.scale,
scale_units=self.scale_units,
**self.plot_args,
)
retv = self._finalize(plot, X, Y, pixX, pixY, pixC)
plot._axes.set_xlim(xx0, xx1)
plot._axes.set_ylim(yy0, yy1)
return retv


class ClumpContourCallback(PlotCallback):
Expand Down

0 comments on commit bdf36b7

Please sign in to comment.