diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py index ee991deba59..64b51da8ca1 100644 --- a/yt/visualization/plot_modifications.py +++ b/yt/visualization/plot_modifications.py @@ -3,12 +3,13 @@ import sys import warnings from abc import ABC, abstractmethod -from functools import update_wrapper +from functools import partial, update_wrapper from numbers import Integral, Number from typing import Any, Dict, Optional, Tuple, Type, Union import matplotlib import numpy as np +from unyt import unyt_quantity from yt._maintenance.deprecation import issue_deprecation_warning from yt._typing import AnyFieldKey @@ -1224,19 +1225,37 @@ class StreamlineCallback(PlotCallback): def __init__( self, - field_x, - field_y, + field_x: AnyFieldKey, + field_y: AnyFieldKey, *, + linewidth: Optional[Union[AnyFieldKey, float]] = None, + linewidth_upscaling: Union[float, unyt_quantity] = 1.0, + color: Optional[AnyFieldKey] = None, factor: Union[Tuple[int, int], int] = 16, density=1, field_color=None, - display_threshold=None, + display_threshold: Optional[Union[float, unyt_quantity]] = None, plot_args=None, **kwargs, ): self.field_x = field_x self.field_y = field_y - self.field_color = field_color + if color is not None and field_color is not None: + raise TypeError( + "Cannot set `color` and `field_color` keyword arguments at the same time." + ) + elif field_color is not None: + issue_deprecation_warning( + "The `field_color` keyword argument is deprecated. " + "Use `color` instead.", + since="4.3", + ) + self.field_color = field_color + else: + self.field_color = color + + self._linewidth = linewidth + self._linewidth_upscaling = linewidth_upscaling self.factor = _validate_factor_tuple(factor) self.dens = density self.display_threshold = display_threshold @@ -1254,44 +1273,41 @@ def __init__( self.plot_args = plot_args def __call__(self, plot): - bounds = self._physical_bounds(plot) xx0, xx1, yy0, yy1 = self._plot_bounds(plot) # We are feeding this size into the pixelizer, where it will properly # set it in reverse order nx = plot.raw_image_shape[1] // self.factor[0] ny = plot.raw_image_shape[0] // self.factor[1] - pixX = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_x, bounds, (nx, ny) - ) - pixY = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_y, bounds, (nx, ny) + pixelize = partial( + plot.data.ds.coordinates.pixelize, + plot.data.axis, + plot.data, + bounds=self._physical_bounds(plot), + size=(nx, ny), ) - if self.field_color: - field_colors = plot.data.ds.coordinates.pixelize( - plot.data.axis, plot.data, self.field_color, bounds, (nx, ny) - ) - - if self.display_threshold: - mask = field_colors > self.display_threshold - lwdefault = matplotlib.rcParams["lines.linewidth"] + pixX = pixelize(field=self.field_x) + pixY = pixelize(field=self.field_y) + if ( + isinstance(self._linewidth, tuple) + and len(self._linewidth) == 2 + and all(isinstance(_, str) for _ in self._linewidth) + ): + linewidth = pixelize(field=self._linewidth) + if plot._swap_axes: + linewidth = linewidth.transpose() + else: + linewidth = self._linewidth - if "linewidth" in self.plot_args: - linewidth = self.plot_args["linewidth"] - else: - linewidth = lwdefault - - try: - linewidth *= mask - self.plot_args["linewidth"] = linewidth - except ValueError as e: - err_msg = ( - "Error applying display threshold: linewidth" - + "must have shape ({}, {}) or be scalar" - ) - err_msg = err_msg.format(nx, ny) - raise ValueError(err_msg) from e + if linewidth is not None: + linewidth *= self._linewidth_upscaling / np.abs(linewidth).max() + if self.field_color is not None: + field_colors = pixelize(field=self.field_color) + if plot._swap_axes: + field_colors = field_colors.transpose() + if linewidth is not None and self.display_threshold is not None: + linewidth *= field_colors > self.display_threshold else: field_colors = None @@ -1313,6 +1329,7 @@ def __call__(self, plot): "v": pixY, "density": self.dens, "color": field_colors, + "linewidth": linewidth, } streamplot_args.update(self.plot_args) plot._axes.streamplot(**streamplot_args)