Skip to content

Commit

Permalink
ENH: add support for passing fields keys for colors and linewidth in …
Browse files Browse the repository at this point in the history
…streamline plot annotations
  • Loading branch information
neutrinoceros committed May 26, 2023
1 parent de65439 commit e1f69a5
Showing 1 changed file with 51 additions and 34 deletions.
85 changes: 51 additions & 34 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import sys
import warnings
from abc import ABC, abstractmethod
from functools import update_wrapper
from functools import partial, update_wrapper
from numbers import Integral, Number
from typing import Any, Dict, Optional, Tuple, Type, Union

import matplotlib
import numpy as np
from unyt import unyt_quantity

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import AnyFieldKey
Expand Down Expand Up @@ -1224,19 +1225,37 @@ class StreamlineCallback(PlotCallback):

def __init__(
self,
field_x,
field_y,
field_x: AnyFieldKey,
field_y: AnyFieldKey,
*,
linewidth: Optional[Union[AnyFieldKey, float]] = None,
linewidth_upscaling: Union[float, unyt_quantity] = 1.0,
color: Optional[AnyFieldKey] = None,
factor: Union[Tuple[int, int], int] = 16,
density=1,
field_color=None,
display_threshold=None,
display_threshold: Optional[Union[float, unyt_quantity]] = None,
plot_args=None,
**kwargs,
):
self.field_x = field_x
self.field_y = field_y
self.field_color = field_color
if color is not None and field_color is not None:
raise TypeError(
"Cannot set `color` and `field_color` keyword arguments at the same time."
)
elif field_color is not None:
issue_deprecation_warning(
"The `field_color` keyword argument is deprecated. "
"Use `color` instead.",
since="4.3",
)
self.field_color = field_color
else:
self.field_color = color

self._linewidth = linewidth
self._linewidth_upscaling = linewidth_upscaling
self.factor = _validate_factor_tuple(factor)
self.dens = density
self.display_threshold = display_threshold
Expand All @@ -1254,44 +1273,41 @@ def __init__(
self.plot_args = plot_args

def __call__(self, plot):
bounds = self._physical_bounds(plot)
xx0, xx1, yy0, yy1 = self._plot_bounds(plot)

# We are feeding this size into the pixelizer, where it will properly
# set it in reverse order
nx = plot.raw_image_shape[1] // self.factor[0]
ny = plot.raw_image_shape[0] // self.factor[1]
pixX = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_x, bounds, (nx, ny)
)
pixY = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_y, bounds, (nx, ny)
pixelize = partial(
plot.data.ds.coordinates.pixelize,
plot.data.axis,
plot.data,
bounds=self._physical_bounds(plot),
size=(nx, ny),
)
if self.field_color:
field_colors = plot.data.ds.coordinates.pixelize(
plot.data.axis, plot.data, self.field_color, bounds, (nx, ny)
)

if self.display_threshold:
mask = field_colors > self.display_threshold
lwdefault = matplotlib.rcParams["lines.linewidth"]
pixX = pixelize(field=self.field_x)
pixY = pixelize(field=self.field_y)
if (
isinstance(self._linewidth, tuple)
and len(self._linewidth) == 2
and all(isinstance(_, str) for _ in self._linewidth)
):
linewidth = pixelize(field=self._linewidth)
if plot._swap_axes:
linewidth = linewidth.transpose()
else:
linewidth = self._linewidth

if "linewidth" in self.plot_args:
linewidth = self.plot_args["linewidth"]
else:
linewidth = lwdefault

try:
linewidth *= mask
self.plot_args["linewidth"] = linewidth
except ValueError as e:
err_msg = (
"Error applying display threshold: linewidth"
+ "must have shape ({}, {}) or be scalar"
)
err_msg = err_msg.format(nx, ny)
raise ValueError(err_msg) from e
if linewidth is not None:
linewidth *= self._linewidth_upscaling / np.abs(linewidth).max()

if self.field_color is not None:
field_colors = pixelize(field=self.field_color)
if plot._swap_axes:
field_colors = field_colors.transpose()
if linewidth is not None and self.display_threshold is not None:
linewidth *= field_colors > self.display_threshold
else:
field_colors = None

Expand All @@ -1313,6 +1329,7 @@ def __call__(self, plot):
"v": pixY,
"density": self.dens,
"color": field_colors,
"linewidth": linewidth,
}
streamplot_args.update(self.plot_args)
plot._axes.streamplot(**streamplot_args)
Expand Down

0 comments on commit e1f69a5

Please sign in to comment.