From d6f49d237cdb685b57fc1e36875e3b8c102bd0d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:03:35 +0100 Subject: [PATCH 01/54] RFC: implement NormHandler and ColorbarHandler --- yt/visualization/_handlers.py | 354 ++++++++++++++++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 yt/visualization/_handlers.py diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py new file mode 100644 index 0000000000..8e966bc09b --- /dev/null +++ b/yt/visualization/_handlers.py @@ -0,0 +1,354 @@ +import weakref +from numbers import Real +from typing import Any, Dict, List, Optional, Type, Union + +import numpy as np +from matplotlib.cm import get_cmap +from matplotlib.colors import Colormap, LogNorm, Normalize, SymLogNorm +from packaging.version import Version +from unyt import Unit, unyt_quantity + +from yt.config import ytcfg +from yt.funcs import get_brewer_cmap, is_sequence, mylog +from yt.visualization._commons import MPL_VERSION + + +class NormHandler: + """ + A bookkeeper class that can hold a fully defined norm object, or dynamically + build one on demand according to a set of constraints. + + If a fully defined norm object is added, any existing constraints are + dropped, and vice versa. These rules are implemented with properties and + watcher patterns. + + It also keeps track of display units so that vmin, vmax and linthresh can be + updated with implicit units. + """ + + # using slots here to minimize the risk of introducing bugs + # since attributes names are essential to this class's implementation + __slots__ = ( + "data_source", + "ds", + "_display_units", + "_vmin", + "_vmax", + "_norm_type", + "_linthresh", + "_norm_type", + "_norm", + ) + _constraint_attrs: List[str] = ["vmin", "vmax", "norm_type", "linthresh"] + + def __init__( + self, + data_source, + *, + display_units: Unit, + vmin: Optional[unyt_quantity] = None, + vmax: Optional[unyt_quantity] = None, + norm_type: Optional[Type[Normalize]] = None, + norm: Optional[Normalize] = None, + linthresh: Optional[float] = None, + ): + self.data_source = weakref.proxy(data_source) + self.ds = data_source.ds # should already be a weakref proxy + self._display_units = display_units + + self._norm = norm + self._vmin = vmin + self._vmax = vmax + self._norm_type = norm_type + self._linthresh = linthresh + + if self.has_norm and self.has_constraints: + raise TypeError( + "NormHandler input is malformed. " + "A norm cannot be passed along other constraints." + ) + + def _get_constraints(self) -> Dict[str, Any]: + return { + attr: getattr(self, attr) + for attr in self.__class__._constraint_attrs + if getattr(self, attr) is not None + } + + @property + def has_constraints(self) -> bool: + return bool(self._get_constraints()) + + def _reset_constraints(self) -> None: + constraints = self._get_constraints() + if not constraints: + return + + msg = ", ".join([f"{name}={value}" for name, value in constraints.items()]) + mylog.warning("Dropping norm constraints (%s)", msg) + for name in constraints.keys(): + setattr(self, name, None) + + @property + def has_norm(self) -> bool: + return self._norm is not None + + def _reset_norm(self): + if not self.has_norm: + return + mylog.warning("Dropping norm (%s)", self.norm) + self._norm = None + + def to_float(self, val: unyt_quantity) -> float: + return float(val.to(self.display_units).d) + + def to_quan(self, val) -> unyt_quantity: + if isinstance(val, unyt_quantity): + return self.ds.quan(val) + elif ( + is_sequence(val) + and len(val) == 2 + and isinstance(val[0], Real) + and isinstance(val[1], (str, Unit)) + ): + return self.ds.quan(*val) + elif isinstance(val, Real): + return self.ds.quan(val, self.display_units) + else: + raise TypeError(f"Could not convert {val!r} to unyt_quantity") + + @property + def display_units(self) -> Unit: + return self._display_units + + @display_units.setter + def display_units(self, newval: Union[str, Unit]) -> None: + self._display_units = Unit(newval) + + def _set_quan_attr( + self, attr: str, newval: Optional[Union[unyt_quantity, float]] + ) -> None: + if newval is None: + setattr(self, attr, None) + elif isinstance(newval, Real): + setattr(self, attr, newval * self.display_units) + else: + try: + quan = self.to_quan(newval) + except TypeError as exc: + raise TypeError( + "Expected None, a float, or a unyt_quantity, " + f"received {newval} with type {type(newval)}" + ) from exc + else: + setattr(self, attr, quan) + + @property + def vmin(self) -> Optional[unyt_quantity]: + return self._vmin + + @vmin.setter + def vmin(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + self._reset_norm() + self._set_quan_attr("_vmin", newval) + + @property + def vmax(self) -> Optional[unyt_quantity]: + return self._vmax + + @vmax.setter + def vmax(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + self._reset_norm() + self._set_quan_attr("_vmax", newval) + + @property + def norm_type(self) -> Optional[Type[Normalize]]: + return self._norm_type + + @norm_type.setter + def norm_type(self, newval: Optional[Type[Normalize]]) -> None: + if not ( + newval is None + or (isinstance(newval, type) and issubclass(newval, Normalize)) + ): + raise TypeError( + "Expected a subclass of matplotlib.colors.Normalize, " + f"received {newval} with type {type(newval)}" + ) + self._reset_norm() + if newval is not SymLogNorm: + self.linthresh = None + self._norm_type = newval + + @property + def norm(self) -> Optional[Normalize]: + return self._norm + + @norm.setter + def norm(self, newval: Normalize) -> None: + if not isinstance(newval, Normalize): + raise TypeError( + "Expected a matplotlib.colors.Normalize object, " + f"received {newval} with type {type(newval)}" + ) + self._reset_constraints() + self._norm = newval + + @property + def linthresh(self) -> Optional[float]: + return self._linthresh + + @linthresh.setter + def linthresh(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + self._reset_norm() + self._set_quan_attr("_linthresh", newval) + if self._linthresh is not None and self._linthresh <= 0: + raise ValueError( + f"linthresh can only be set to stricly positive values, got {newval}" + ) + if newval is not None: + self.norm_type = SymLogNorm + + def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: + if self.has_norm: + return self.norm + + finite_values_mask = np.isfinite(data) + if self.vmin is not None: + dvmin = self.to_float(self.vmin) + elif np.any(finite_values_mask): + dvmin = self.to_float(np.nanmin(data[finite_values_mask])) + else: + dvmin = 1 * getattr(data, "units", 1) + kw.setdefault("vmin", dvmin) + + if self.vmax is not None: + dvmax = self.to_float(self.vmax) + elif np.any(finite_values_mask): + dvmax = self.to_float(np.nanmax(data[finite_values_mask])) + else: + dvmax = 1 * getattr(data, "units", 1) + kw.setdefault("vmax", dvmax) + + min_abs_val, max_abs_val = np.sort(np.abs((kw["vmin"], kw["vmax"]))) + if self.norm_type is not None: + # this is a convenience mechanism for backward compat, + # allowing to toggle between lin and log scaling without detailled user input + norm_type = self.norm_type + else: + if kw["vmin"] == kw["vmax"] or not np.any(np.isfinite(data)): + norm_type = Normalize + elif kw["vmin"] <= 0: + norm_type = SymLogNorm + elif ( + Version("3.3") <= MPL_VERSION < Version("3.5") + and kw["vmin"] == 0 + and kw["vmax"] > 0 + ): + # normally, a LogNorm scaling would still be OK here because + # LogNorm will mask 0 values when calculating vmin. But + # due to a bug in matplotlib's imshow, if the data range + # spans many orders of magnitude while containing zero points + # vmin can get rescaled to 0, resulting in an error when the image + # gets drawn. So here we switch to symlog to avoid that until + # a fix is in -- see PR #3161 and linked issue. + cutoff_sigdigs = 15 + if ( + np.log10(np.nanmax(data[np.isfinite(data)])) + - np.log10(np.nanmin(data[data > 0])) + > cutoff_sigdigs + ): + norm_type = SymLogNorm + else: + norm_type = LogNorm + else: + norm_type = LogNorm + + if norm_type is SymLogNorm: + # if cblinthresh is not specified, try to come up with a reasonable default + if self.linthresh is not None: + linthresh = self.to_float(self.linthresh) + elif min_abs_val > 0: + linthresh = min_abs_val + else: + linthresh = max_abs_val / 1000 + kw.setdefault("linthresh", linthresh) + if MPL_VERSION >= Version("3.2"): + # note that this creates an inconsistency between mpl versions + # since the default value previous to mpl 3.4.0 is np.e + # but it is only exposed since 3.2.0 + kw.setdefault("base", 10) + + return norm_type(*args, **kw) + + +class ColorbarHandler: + __slots__ = ("_draw_cbar", "_draw_minorticks", "_cmap", "_background_color") + + def __init__( + self, + *, + draw_cbar: bool = True, + draw_minorticks: bool = True, + cmap: Optional[Union[Colormap, str]] = None, + background_color: Optional[str] = "white", + ): + self._draw_cbar = draw_cbar + self._draw_minorticks = draw_minorticks + self._cmap: Optional[Colormap] = None + self.cmap = cmap + self._background_color = background_color + + @property + def draw_cbar(self) -> bool: + return self._draw_cbar + + @draw_cbar.setter + def draw_cbar(self, newval) -> None: + if not isinstance(newval, bool): + raise TypeError( + f"Excpected a boolean, got {newval} with type {type(newval)}" + ) + self._draw_cbar = newval + + @property + def draw_minorticks(self) -> bool: + return self._draw_minorticks + + @draw_minorticks.setter + def draw_minorticks(self, newval) -> None: + if not isinstance(newval, bool): + raise TypeError( + f"Excpected a boolean, got {newval} with type {type(newval)}" + ) + self._draw_minoticks = newval + + @property + def cmap(self) -> Colormap: + return self._cmap or get_cmap(ytcfg.get("yt", "default_colormap")) + + @cmap.setter + def cmap(self, newval) -> None: + if isinstance(newval, Colormap) or newval is None: + self._cmap = newval + elif isinstance(newval, str): + self._cmap = get_cmap(newval) + elif is_sequence(newval): + # tuple colormaps are from palettable (or brewer2mpl) + self._cmap = get_brewer_cmap(newval) + else: + raise TypeError( + "Expected a colormap object or name, " + f"got {newval} with type {type(newval)}" + ) + + @property + def background_color(self): + return self._background_color + + @background_color.setter + def background_color(self, newval): + if newval is None: + self._background_color = self.cmap(0) + else: + self._background_color = newval From 2841a521303e8dff8aa22779645d988c78ecf58e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:05:05 +0100 Subject: [PATCH 02/54] RFC: refactor base plot types and containers to utilise NormHandler and ColorbarHandler --- yt/_maintenance/backports.py | 15 + yt/utilities/exceptions.py | 4 + yt/visualization/_commons.py | 135 +++++++++ yt/visualization/_handlers.py | 3 + yt/visualization/base_plot_types.py | 270 +++++++++--------- yt/visualization/plot_container.py | 412 ++++++++++++++++------------ 6 files changed, 532 insertions(+), 307 deletions(-) diff --git a/yt/_maintenance/backports.py b/yt/_maintenance/backports.py index 7e84b15419..3f4a2ac5f5 100644 --- a/yt/_maintenance/backports.py +++ b/yt/_maintenance/backports.py @@ -80,3 +80,18 @@ def __get__(self, instance, owner=None): else: pass + + +builtin_zip = zip +if sys.version_info >= (3, 10): + zip = builtin_zip +else: + # this function is deprecated in more_itertools + # because it is superseded by the standard library + from more_itertools import zip_equal + + def zip(*args, strict=False): + if strict: + return zip_equal(*args) + else: + return builtin_zip(*args) diff --git a/yt/utilities/exceptions.py b/yt/utilities/exceptions.py index 39b8b45712..42531a1bc2 100644 --- a/yt/utilities/exceptions.py +++ b/yt/utilities/exceptions.py @@ -898,6 +898,10 @@ def __str__(self): return msg +class YTConfigurationError(YTException): + pass + + class GenerationInProgress(Exception): def __init__(self, fields): self.fields = fields diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index f28e79bbe4..ae4f503d53 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -4,9 +4,18 @@ from functools import wraps from typing import Optional, Type, TypeVar +if sys.version_info >= (3, 10): + pass +else: + from yt._maintenance.backports import zip + import matplotlib +import numpy as np +from more_itertools import always_iterable from packaging.version import Version +from yt.config import ytcfg + from ._mpl_imports import ( FigureCanvasAgg, FigureCanvasBase, @@ -201,3 +210,129 @@ def _swap_arg_pair_order(*args): new_args.append(args[x_id + 1]) new_args.append(args[x_id]) return tuple(new_args) + + +def get_log_minorticks(vmin, vmax): + """calculate positions of linear minorticks on a log colorbar + + Parameters + ---------- + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ + expA = np.floor(np.log10(vmin)) + expB = np.floor(np.log10(vmax)) + cofA = np.ceil(vmin / 10**expA).astype("int64") + cofB = np.floor(vmax / 10**expB).astype("int64") + lmticks = [] + while cofA * 10**expA <= cofB * 10**expB: + if expA < expB: + lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA)) + cofA = 1 + expA += 1 + else: + lmticks = np.hstack( + (lmticks, np.linspace(cofA, cofB, cofB - cofA + 1) * 10**expA) + ) + expA += 1 + return np.array(lmticks) + + +def get_symlog_minorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray: + """calculate positions of linear minorticks on a symmetric log colorbar + + Parameters + ---------- + linthresh : float + the threshold for the linear region + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ + if vmin > 0: + return get_log_minorticks(vmin, vmax) + elif vmax < 0 and vmin < 0: + return -get_log_minorticks(-vmax, -vmin) + elif vmin == 0: + return np.hstack((0, get_log_minorticks(linthresh, vmax))) + elif vmax == 0: + return np.hstack((-get_log_minorticks(linthresh, -vmin)[::-1], 0)) + else: + return np.hstack( + ( + -get_log_minorticks(linthresh, -vmin)[::-1], + 0, + get_log_minorticks(linthresh, vmax), + ) + ) + + +def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray: + if vmin >= 0.0: + yticks = [vmin] + list( + 10 + ** np.arange( + np.rint(np.log10(linthresh)), + np.ceil(np.log10(1.1 * vmax)), + ) + ) + elif vmax <= 0.0: + if MPL_VERSION >= Version("3.5.0b"): + offset = 0 + else: + offset = 1 + + yticks = list( + -( + 10 + ** np.arange( + np.floor(np.log10(-vmin)), + np.rint(np.log10(linthresh)) - offset, + -1, + ) + ) + ) + [vmax] + else: + yticks = ( + list( + -( + 10 + ** np.arange( + np.floor(np.log10(-vmin)), + np.rint(np.log10(linthresh)) - 1, + -1, + ) + ) + ) + + [0] + + list( + 10 + ** np.arange( + np.rint(np.log10(linthresh)), + np.ceil(np.log10(1.1 * vmax)), + ) + ) + ) + if yticks[-1] > vmax: + yticks.pop() + return np.array(yticks) + + +def get_default_from_config(data_source, *, field, keys, defaults): + _keys = list(always_iterable(keys)) + _defaults = list(always_iterable(defaults)) + + ftype, fname = data_source._determine_fields(field)[0] + ret = [ + ytcfg.get_most_specific("plot", ftype, fname, key, fallback=default) + for key, default in zip(_keys, _defaults, strict=True) + ] + if len(ret) == 1: + return ret[0] + else: + return ret diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 8e966bc09b..b3392ea3fe 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -266,6 +266,9 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: if norm_type is SymLogNorm: # if cblinthresh is not specified, try to come up with a reasonable default + min_abs_val, max_abs_val = np.sort( + np.abs((self.to_float(np.nanmin(data)), self.to_float(np.nanmax(data)))) + ) if self.linthresh is not None: linthresh = self.to_float(self.linthresh) elif min_abs_val > 0: diff --git a/yt/visualization/base_plot_types.py b/yt/visualization/base_plot_types.py index 34904350cc..41f7530c35 100644 --- a/yt/visualization/base_plot_types.py +++ b/yt/visualization/base_plot_types.py @@ -1,19 +1,25 @@ -import warnings +from abc import ABC from io import BytesIO +from typing import Optional, Tuple, Union import matplotlib import numpy as np +from matplotlib.axis import Axis +from matplotlib.colors import LogNorm, Normalize, SymLogNorm +from matplotlib.figure import Figure +from matplotlib.ticker import LogFormatterMathtext from packaging.version import Version -from yt.funcs import ( - get_brewer_cmap, - get_interactivity, - is_sequence, - matplotlib_style_context, - mylog, -) +from yt.funcs import get_interactivity, is_sequence, matplotlib_style_context, mylog +from yt.visualization._handlers import ColorbarHandler, NormHandler -from ._commons import MPL_VERSION, get_canvas, validate_image_name +from ._commons import ( + MPL_VERSION, + get_canvas, + get_symlog_majorticks, + get_symlog_minorticks, + validate_image_name, +) BACKEND_SPECS = { "GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"], @@ -77,7 +83,15 @@ def __init__(self, viewer, window_plot, frb, field, font_properties, font_color) class PlotMPL: """A base class for all yt plots made using matplotlib, that is backend independent.""" - def __init__(self, fsize, axrect, figure, axes): + def __init__( + self, + fsize, + axrect, + *, + norm_handler: NormHandler, + figure: Optional[Figure] = None, + axes: Optional[Axis] = None, + ): """Initialize PlotMPL class""" import matplotlib.figure @@ -106,6 +120,8 @@ def __init__(self, fsize, axrect, figure, axes): which="both", axis="both", direction="in", top=True, right=True ) + self.norm_handler = norm_handler + def _create_axes(self, axrect): self.axes = self.figure.add_axes(axrect) @@ -141,9 +157,7 @@ def save(self, name, mpl_kwargs=None, canvas=None): if mpl_kwargs is None: mpl_kwargs = {} - if "papertype" not in mpl_kwargs and Version(matplotlib.__version__) < Version( - "3.3.0" - ): + if "papertype" not in mpl_kwargs and MPL_VERSION < Version("3.3.0"): mpl_kwargs["papertype"] = "auto" name = validate_image_name(name) @@ -195,13 +209,38 @@ def _repr_png_(self): return f.read() -class ImagePlotMPL(PlotMPL): +class ImagePlotMPL(PlotMPL, ABC): """A base class for yt plots made using imshow""" - def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax): + _default_font_size = 18.0 + + def __init__( + self, + fsize=None, + axrect=None, + caxrect=None, + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, + figure: Optional[Figure] = None, + axes: Optional[Axis] = None, + cax: Optional[Axis] = None, + ): """Initialize ImagePlotMPL class object""" - super().__init__(fsize, axrect, figure, axes) - self.zmin, self.zmax = zlim + self.colorbar_handler = colorbar_handler + _missing_layout_specs = [_ is None for _ in (fsize, axrect, caxrect)] + + if all(_missing_layout_specs): + fsize, axrect, caxrect = self._get_best_layout() + elif any(_missing_layout_specs): + raise TypeError( + "ImagePlotMPL cannot be initialized with partially specified layout." + ) + + super().__init__( + fsize, axrect, norm_handler=norm_handler, figure=figure, axes=axes + ) + if cax is None: self.cax = self.figure.add_axes(caxrect) else: @@ -209,12 +248,39 @@ def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax): cax.set_position(caxrect) self.cax = cax - def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): + def _setup_layout_constraints( + self, figure_size: Union[Tuple[float, float], float], fontsize: float + ): + # Setup base layout attributes + # derived classes need to call this before super().__init__ + # but they are free to do other stuff in between + + if isinstance(figure_size, tuple): + assert len(figure_size) == 2 + assert all(isinstance(_, float) for _ in figure_size) + self._figure_size = figure_size + else: + assert isinstance(figure_size, float) + self._figure_size = (figure_size, figure_size) + + self._draw_axes = True + fontscale = float(fontsize) / self.__class__._default_font_size + if fontscale < 1.0: + fontscale = np.sqrt(fontscale) + + self._cb_size = 0.0375 * self._figure_size[0] + self._ax_text_size = [1.2 * fontscale, 0.9 * fontscale] + self._top_buff_size = 0.30 * fontscale + self._aspect = 1.0 + + def _reset_layout(self) -> None: + size, axrect, caxrect = self._get_best_layout() + self.axes.set_position(axrect) + self.cax.set_position(caxrect) + self.figure.set_size_inches(*size) + + def _init_image(self, data, extent, aspect): """Store output of imshow in image variable""" - cbnorm_kwargs = dict( - vmin=float(self.zmin) if self.zmin is not None else None, - vmax=float(self.zmax) if self.zmax is not None else None, - ) if MPL_VERSION < Version("3.2"): # with MPL 3.1 we use np.inf as a mask instead of np.nan @@ -224,50 +290,8 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): # see https://github.com/yt-project/yt/pull/2517 and https://github.com/yt-project/yt/pull/3793 data[~np.isfinite(data)] = np.nan - zmin = float(self.zmin) if self.zmin is not None else np.nanmin(data) - zmax = float(self.zmax) if self.zmax is not None else np.nanmax(data) - - if cbnorm == "symlog": - # if cblinthresh is not specified, try to come up with a reasonable default - min_abs_val, max_abs_val = np.sort( - np.abs((np.nanmin(data), np.nanmax(data))) - ) - if cblinthresh is not None: - if zmin * zmax > 0 and cblinthresh < min_abs_val: - # see https://github.com/yt-project/yt/issues/3564 - warnings.warn( - f"Cannot set a symlog norm with linear threshold {cblinthresh} " - f"lower than the minimal absolute data value {min_abs_val} . " - "Switching to log norm." - ) - cbnorm = "log10" - elif min_abs_val > 0: - cblinthresh = min_abs_val - else: - cblinthresh = max_abs_val / 1000 - - if cbnorm == "log10": - cbnorm_cls = matplotlib.colors.LogNorm - elif cbnorm == "linear": - cbnorm_cls = matplotlib.colors.Normalize - elif cbnorm == "symlog": - cbnorm_kwargs.update(dict(linthresh=cblinthresh)) - if MPL_VERSION >= Version("3.2.0"): - # note that this creates an inconsistency between mpl versions - # since the default value previous to mpl 3.4.0 is np.e - # but it is only exposed since 3.2.0 - cbnorm_kwargs["base"] = 10 - - cbnorm_cls = matplotlib.colors.SymLogNorm - else: - raise ValueError(f"Unknown value `cbnorm` == {cbnorm}") - - norm = cbnorm_cls(**cbnorm_kwargs) - + norm = self.norm_handler.get_norm(data) extent = [float(e) for e in extent] - # tuple colormaps are from palettable (or brewer2mpl) - if isinstance(cmap, tuple): - cmap = get_brewer_cmap(cmap) if self._transform is None: # sets the transform to be an ax.TransData object, where the @@ -310,67 +334,68 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): extent=extent, norm=norm, aspect=aspect, - cmap=cmap, + cmap=self.colorbar_handler.cmap, interpolation="nearest", transform=transform, ) - if cbnorm == "symlog": - formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh) - self.cb = self.figure.colorbar(self.image, self.cax, format=formatter) + self._set_axes(norm) - if zmin >= 0.0: - yticks = [zmin] + list( - 10 - ** np.arange( - np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(1.1 * zmax)), - ) - ) - elif zmax <= 0.0: - if MPL_VERSION >= Version("3.5.0b"): - offset = 0 - else: - offset = 1 - - yticks = list( - -( - 10 - ** np.arange( - np.floor(np.log10(-zmin)), - np.rint(np.log10(cblinthresh)) - offset, - -1, - ) - ) - ) + [zmax] - else: - yticks = ( - list( - -( - 10 - ** np.arange( - np.floor(np.log10(-zmin)), - np.rint(np.log10(cblinthresh)) - 1, - -1, - ) - ) - ) - + [0] - + list( - 10 - ** np.arange( - np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(1.1 * zmax)), - ) - ) + def _set_axes(self, norm: Normalize) -> None: + if isinstance(norm, SymLogNorm): + formatter = LogFormatterMathtext(linthresh=norm.linthresh) + self.cb = self.figure.colorbar(self.image, self.cax, format=formatter) + self.cb.set_ticks( + get_symlog_majorticks( + linthresh=norm.linthresh, vmin=norm.vmin, vmax=norm.vmax ) - if yticks[-1] > zmax: - yticks.pop() - self.cb.set_ticks(yticks) + ) else: self.cb = self.figure.colorbar(self.image, self.cax) self.cax.tick_params(which="both", axis="y", direction="in") + fmt_kwargs = dict(style="scientific", scilimits=(-2, 3), useMathText=True) + self.image.axes.ticklabel_format(**fmt_kwargs) + if type(norm) not in (LogNorm, SymLogNorm): + self.cb.ax.ticklabel_format(**fmt_kwargs) + if self.colorbar_handler.draw_minorticks: + if isinstance(norm, SymLogNorm): + if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"): + # no known working method to draw symlog minor ticks + # see https://github.com/yt-project/yt/issues/3535 + pass + else: + flinthresh = 10 ** np.floor(np.log10(norm.linthresh)) + absmax = np.abs((norm.vmin, norm.vmax)).max() + if (absmax - flinthresh) / absmax < 0.1: + flinthresh /= 10 + mticks = get_symlog_minorticks(flinthresh, norm.vmin, norm.vmax) + if MPL_VERSION < Version("3.5.0b"): + # https://github.com/matplotlib/matplotlib/issues/21258 + mticks = self.image.norm(mticks) + self.cax.yaxis.set_ticks(mticks, minor=True) + + elif isinstance(norm, LogNorm): + self.cax.minorticks_on() + self.cax.xaxis.set_visible(False) + + else: + self.cax.minorticks_on() + else: + self.cax.minorticks_off() + + self.image.axes.set_facecolor(self.colorbar_handler.background_color) + def _get_best_layout(self): + # this method is called in ImagePlotMPL.__init__ + # required attributes + # - self._figure_size: Union[float, Tuple[float, float]] + # - self._aspect: float + # - self._ax_text_size: Tuple[float, float] + # - self._draw_axes: bool + # - self.colorbar_handler: ColorbarHandler + + # optional attribtues + # - self._unit_aspect: float # Ensure the figure size along the long axis is always equal to _figure_size unit_aspect = getattr(self, "_unit_aspect", 1) @@ -385,7 +410,7 @@ def _get_best_layout(self): else: y_fig_size /= scaling - if self._draw_colorbar: + if self.colorbar_handler.draw_cbar: cb_size = self._cb_size cb_text_size = self._ax_text_size[1] + 0.45 else: @@ -401,7 +426,7 @@ def _get_best_layout(self): top_buff_size = self._top_buff_size - if not self._draw_axes and not self._draw_colorbar: + if not self._draw_axes and not self.colorbar_handler.draw_cbar: x_axis_size = 0.0 y_axis_size = 0.0 cb_size = 0.0 @@ -458,18 +483,15 @@ def _toggle_axes(self, choice, draw_frame=None): self.axes.set_frame_on(draw_frame) self.axes.get_xaxis().set_visible(choice) self.axes.get_yaxis().set_visible(choice) - size, axrect, caxrect = self._get_best_layout() - self.axes.set_position(axrect) - self.cax.set_position(caxrect) - self.figure.set_size_inches(*size) + self._reset_layout() - def _toggle_colorbar(self, choice): + def _toggle_colorbar(self, choice: bool): """ Turn on/off displaying the colorbar for a plot choice = True or False """ - self._draw_colorbar = choice + self.colorbar_handler.draw_cbar = choice self.cax.set_visible(choice) size, axrect, caxrect = self._get_best_layout() self.axes.set_position(axrect) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 2a310ef7c8..3b3fbbe81a 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -6,20 +6,24 @@ import warnings from collections import defaultdict from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union +import matplotlib import numpy as np -from matplotlib.cm import get_cmap +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from matplotlib.font_manager import FontProperties -from more_itertools.more import always_iterable +from unyt.dimensions import length +from yt._maintenance.deprecation import issue_deprecation_warning from yt.config import ytcfg from yt.data_objects.time_series import DatasetSeries -from yt.funcs import dictWithFactory, ensure_dir, is_sequence, iter_fields, mylog -from yt.units import YTQuantity +from yt.funcs import ensure_dir, is_sequence, iter_fields from yt.units.unit_object import Unit # type: ignore from yt.utilities.definitions import formatted_length_unit_names -from yt.utilities.exceptions import YTNotInsideNotebook +from yt.utilities.exceptions import YTConfigurationError, YTNotInsideNotebook +from yt.visualization._commons import get_default_from_config +from yt.visualization._handlers import ColorbarHandler, NormHandler +from yt.visualization.base_plot_types import PlotMPL from ._commons import ( DEFAULT_FONT_PROPERTIES, @@ -64,66 +68,6 @@ def newfunc(self, field, *args, **kwargs): return newfunc -def get_log_minorticks(vmin, vmax): - """calculate positions of linear minorticks on a log colorbar - - Parameters - ---------- - vmin : float - the minimum value in the colorbar - vmax : float - the maximum value in the colorbar - - """ - expA = np.floor(np.log10(vmin)) - expB = np.floor(np.log10(vmax)) - cofA = np.ceil(vmin / 10**expA).astype("int64") - cofB = np.floor(vmax / 10**expB).astype("int64") - lmticks = [] - while cofA * 10**expA <= cofB * 10**expB: - if expA < expB: - lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA)) - cofA = 1 - expA += 1 - else: - lmticks = np.hstack( - (lmticks, np.linspace(cofA, cofB, cofB - cofA + 1) * 10**expA) - ) - expA += 1 - return np.array(lmticks) - - -def get_symlog_minorticks(linthresh, vmin, vmax): - """calculate positions of linear minorticks on a symmetric log colorbar - - Parameters - ---------- - linthresh : float - the threshold for the linear region - vmin : float - the minimum value in the colorbar - vmax : float - the maximum value in the colorbar - - """ - if vmin > 0: - return get_log_minorticks(vmin, vmax) - elif vmax < 0 and vmin < 0: - return -get_log_minorticks(-vmax, -vmin) - elif vmin == 0: - return np.hstack((0, get_log_minorticks(linthresh, vmax))) - elif vmax == 0: - return np.hstack((-get_log_minorticks(linthresh, -vmin)[::-1], 0)) - else: - return np.hstack( - ( - -get_log_minorticks(linthresh, -vmin)[::-1], - 0, - get_log_minorticks(linthresh, vmax), - ) - ) - - field_transforms = {} @@ -166,66 +110,44 @@ def __init__(self, data_source, default_factory=None): class PlotContainer(abc.ABC): """A container for generic plots""" + _plot_dict_type: Type[PlotDictionary] = PlotDictionary _plot_type: Optional[str] = None _plot_valid = False - # Plot defaults - _colormap_config: dict - _log_config: dict - _units_config: dict + _default_figure_size = tuple(matplotlib.rcParams["figure.figsize"]) + _default_font_size = 14.0 - def __init__(self, data_source, figure_size, fontsize): + def __init__(self, data_source, figure_size=None, fontsize: float = None): self.data_source = data_source self.ds = data_source.ds self.ts = self._initialize_dataset(self.ds) - if is_sequence(figure_size): - self.figure_size = float(figure_size[0]), float(figure_size[1]) - else: - self.figure_size = float(figure_size) + self.plots = self.__class__._plot_dict_type(data_source) + + self._set_figure_size(figure_size) + if fontsize is None: + fontsize = self.__class__._default_font_size if sys.version_info >= (3, 9): font_dict = DEFAULT_FONT_PROPERTIES | {"size": fontsize} else: - font_dict = {**DEFAULT_FONT_PROPERTIES, "size": fontsize} + font_dict = {**DEFAULT_FONT_PROPERTIES, "size": fontsize} # type:ignore self._font_properties = FontProperties(**font_dict) self._font_color = None self._xlabel = None self._ylabel = None - self._minorticks = {} - self._field_transform = {} - - self.setup_defaults() - - def setup_defaults(self): - def default_from_config(keys, defaults): - _keys = list(always_iterable(keys)) - _defaults = list(always_iterable(defaults)) - - def getter(field): - ftype, fname = self.data_source._determine_fields(field)[0] - ret = [ - ytcfg.get_most_specific("plot", ftype, fname, key, fallback=default) - for key, default in zip(_keys, _defaults) - ] - if len(ret) == 1: - return ret[0] - return ret - - return getter - - default_cmap = ytcfg.get("yt", "default_colormap") - self._colormap_config = dictWithFactory( - default_from_config("cmap", default_cmap) - )() - self._log_config = dictWithFactory( - default_from_config(["log", "linthresh"], [None, None]) - )() - self._units_config = dictWithFactory(default_from_config("units", [None]))() + self._minorticks: Dict[Tuple[str, str], bool] = {} @accepts_all_fields @invalidate_plot - def set_log(self, field, log, linthresh=None, symlog_auto=False): + def set_log( + self, + field, + log: Optional[bool] = None, + *, + linthresh: Optional[Union[float, str]] = None, + symlog_auto: Optional[bool] = None, # deprecated + ): """set a field to log, linear, or symlog. Symlog scaling is a combination of linear and log, where from 0 to a @@ -240,30 +162,56 @@ def set_log(self, field, log, linthresh=None, symlog_auto=False): field : string the field to set a transform if field == 'all', applies to all plots. - log : boolean - Log on/off: on means log scaling; off means linear scaling. Unless - a linthresh is set or symlog_auto is set in which case symlog is used. - linthresh : float, optional + log : boolean, optional + Log on/off: on means log scaling; off means linear scaling. + linthresh : float, or 'auto', optional when using symlog scaling, linthresh is the value at which scaling transitions from linear to logarithmic. linthresh must be positive. Note: setting linthresh will automatically enable symlog scale - symlog_auto : boolean - if symlog_auto is True, then yt will use symlog scaling and attempt to - determine a linthresh automatically. Setting a linthresh manually - overrides this value. + Note that *log* and *linthresh* are mutually exclusive arguments """ - if symlog_auto: - self._field_transform[field] = symlog_transform - if log: - self._field_transform[field] = log_transform - else: - self._field_transform[field] = linear_transform + if log is None and linthresh is None and symlog_auto is None: + raise TypeError("set_log requires log or linthresh be set") + + if symlog_auto is not None: + issue_deprecation_warning( + "the symlog_auto argument is deprecated. Use linthresh='auto' instead", + since="4.1", + ) + if symlog_auto is True: + linthresh = "auto" + elif symlog_auto is False: + pass + else: + raise TypeError( + "Received invalid value for parameter symlog_auto. " + f"Expected a boolean, got {symlog_auto!r}" + ) + + pnh = self.plots[field].norm_handler + if linthresh is not None: - if not linthresh > 0.0: - raise ValueError('"linthresh" must be positive') - self._field_transform[field] = symlog_transform - self._field_transform[field].func = linthresh + if isinstance(linthresh, str): + if linthresh == "auto": + pnh.norm_type = SymLogNorm + else: + raise ValueError( + "Expected a number, a unyt_quantity, a (float, 'unit') tuple, or 'auto'. " + f"Got linthresh={linthresh!r}" + ) + else: + # pnh takes care of switching to symlog when linthresh is set + pnh.linthresh = linthresh + elif log is True: + pnh.norm_type = LogNorm + elif log is False: + pnh.norm_type = Normalize + else: + raise TypeError( + f"Could not parse arguments log={log!r}, linthresh={linthresh!r}" + ) + return self def get_log(self, field): @@ -278,21 +226,68 @@ def get_log(self, field): """ # devnote : accepts_all_fields decorator is not applicable here because # the return variable isn't self + issue_deprecation_warning( + "The get_log method is not reliable and is deprecated. " + "Please do not rely on it.", + since="4.1", + ) log = {} if field == "all": fields = list(self.plots.keys()) else: fields = field for field in self.data_source._determine_fields(fields): - log[field] = self._field_transform[field] == log_transform + pnh = self.plots[field].norm_handler + if pnh.norm is not None: + log[field] = type(pnh.norm) is LogNorm + elif pnh.norm_type is not None: + log[field] = pnh.norm_type is LogNorm + else: + # the NormHandler object has no constraints yet + # so we'll assume defaults + log[field] = True return log @invalidate_plot - def set_transform(self, field, name): + def set_transform(self, field, name: str): field = self.data_source._determine_fields(field)[0] - if name not in field_transforms: - raise KeyError(name) - self._field_transform[field] = field_transforms[name] + pnh = self.plots[field].norm_handler + pnh.norm_type = { + "linear": Normalize, + "log10": LogNorm, + "symlog": SymLogNorm, + }[name] + return self + + @accepts_all_fields + @invalidate_plot + def set_norm(self, field, norm: Normalize): + r""" + Set a custom ``matplotlib.colors.Normalize`` to plot *field*. + + Any constraints previously set with `set.log`, `set.zlim` will be + dropped. + + Note that any float value attached to *norm* (e.g. vmin, vmax, + vcenter ...) will be read in the current displayed units, which can be + controlled with the `set_units` method. + + Parameters + ---------- + field : str or tuple[str, str] + if field == 'all', applies to all plots. + norm : matplotlib.colors.Normalize + see https://matplotlib.org/stable/tutorials/colors/colormapnorms.html + """ + + if field == "all": + fields = list(self.plots.keys()) + else: + fields = field + + for field in self.data_source._determine_fields(fields): + pnh = self.plots[field].norm_handler + pnh.norm = norm return self @accepts_all_fields @@ -315,6 +310,7 @@ def set_minorticks(self, field, state): self._minorticks[field] = state return self + @abc.abstractmethod def _setup_plots(self): # Left blank to be overridden in subclasses pass @@ -361,7 +357,6 @@ def _switch_ds(self, new_ds, data_source=None): lim = tuple(new_ds.quan(l.value, str(l.units)) for l in lim) setattr(self, lim_name, lim) self.plots.data_source = new_object - self._background_color.data_source = new_object self._colorbar_label.data_source = new_object self._setup_plots() @@ -458,6 +453,16 @@ def set_font_size(self, size): """ return self.set_font({"size": size}) + def _set_figure_size(self, size): + if size is None: + self.figure_size = self.__class__._default_figure_size + elif is_sequence(size): + if len(size) != 2: + raise TypeError(f"Expected a single float or a pair, got {size}") + self.figure_size = float(size[0]), float(size[1]) + else: + self.figure_size = float(size) + @invalidate_plot @invalidate_figure def set_figure_size(self, size): @@ -465,11 +470,13 @@ def set_figure_size(self, size): parameters ---------- - size : float - The size of the figure on the longest axis (in units of inches), - including the margins but not the colorbar. + size : float, a sequence of two floats, or None + The size of the figure (in units of inches), including the margins + but not the colorbar. If a single float is passed, it's interpreted + as the size along the long axis. + Pass None to reset """ - self.figure_size = float(size) + self._set_figure_size(size) return self @validate_plot @@ -531,8 +538,7 @@ def save( new_name = validate_image_name(name, suffix) if new_name == name: - # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here - for v in self.plots.values(): # type: ignore + for v in self.plots.values(): out_name = v.save(name, mpl_kwargs) names.append(out_name) return names @@ -554,7 +560,7 @@ def save( plot_type = "OffAxisSlice" # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here - for k, v in self.plots.items(): # type: ignore + for k, v in self.plots.items(): if isinstance(k, tuple): k = k[1] @@ -857,19 +863,58 @@ def show_axes(self, field=None): return self -class ImagePlotContainer(PlotContainer): +class ImagePlotContainer(PlotContainer, abc.ABC): """A container for plots with colorbars.""" _colorbar_valid = False def __init__(self, data_source, figure_size, fontsize): super().__init__(data_source, figure_size, fontsize) - self.plots = PlotDictionary(data_source) self._callbacks = [] - self._cbar_minorticks = {} - self._background_color = PlotDictionary(self.data_source, lambda: "w") self._colorbar_label = PlotDictionary(self.data_source, lambda: None) + def _get_default_handlers( + self, field, default_display_units: Unit + ) -> Tuple[NormHandler, ColorbarHandler]: + + usr_units_str = get_default_from_config( + self.data_source, field=field, keys="units", defaults=[None] + ) + if usr_units_str is not None: + usr_units = Unit(usr_units_str) + d1 = usr_units.dimensions + d2 = default_display_units.dimensions + + if d1 == d2: + display_units = usr_units + elif getattr(self, "projected", False) and d2 / d1 == length: + path_length_units = Unit( + ytcfg.get_most_specific( + "plot", *field, "path_length_units", fallback="cm" + ), + registry=self.data_source.ds.unit_registry, + ) + display_units = usr_units * path_length_units + else: + raise YTConfigurationError( + f"Invalid units in configuration file for field {field!r}. " + f"Found {usr_units!r}" + ) + else: + display_units = default_display_units + + pnh = NormHandler(self.data_source, display_units=display_units) + + cbh = ColorbarHandler( + cmap=get_default_from_config( + self.data_source, + field=field, + keys="cmap", + defaults=[None], + ) + ) + return pnh, cbh + @accepts_all_fields @invalidate_plot def set_cmap(self, field, cmap): @@ -888,7 +933,7 @@ def set_cmap(self, field, cmap): """ self._colorbar_valid = False - self._colormap_config[field] = cmap + self.plots[field].colorbar_handler.cmap = cmap return self @accepts_all_fields @@ -907,17 +952,13 @@ def set_background_color(self, field, color=None): the color map """ - if color is None: - cmap = self._colormap_config[field] - if isinstance(cmap, str): - cmap = get_cmap(cmap) - color = cmap(0) - self._background_color[field] = color + cbh = self[field].colorbar_handler + cbh.background_color = color return self @accepts_all_fields @invalidate_plot - def set_zlim(self, field, zmin, zmax, dynamic_range=None): + def set_zlim(self, field, zmin=None, zmax=None, dynamic_range=None): """set the scale of the colormap Parameters @@ -943,45 +984,26 @@ def set_zlim(self, field, zmin, zmax, dynamic_range=None): """ - def _sanitize_units(z, _field): - # convert dimensionful inputs to float - if isinstance(z, tuple): - z = self.ds.quan(*z) - if isinstance(z, YTQuantity): - try: - plot_units = self.frb[_field].units - z = z.to(plot_units).value - except AttributeError: - # only certain subclasses have a frb attribute - # they can rely on for inspecting units - mylog.warning( - "%s class doesn't support zmin/zmax" - " as tuples or unyt_quantitiy", - self.__class__.__name__, - ) - z = z.value - return z - if field == "all": fields = list(self.plots.keys()) else: fields = field + if zmin is None and zmax is None: + raise TypeError("Missing required argument zmin or zmax") for field in self.data_source._determine_fields(fields): - myzmin = _sanitize_units(zmin, field) - myzmax = _sanitize_units(zmax, field) - if zmin == "min": - myzmin = self.plots[field].image._A.min() - if zmax == "max": - myzmax = self.plots[field].image._A.max() if dynamic_range is not None: - if zmax is None: - myzmax = myzmin * dynamic_range + if zmax is None and zmin is not None: + zmax = zmin * dynamic_range + elif zmin is None and zmax is not None: + zmin = zmax / dynamic_range else: - myzmin = myzmax / dynamic_range - if myzmin > 0.0 and self._field_transform[field] == symlog_transform: - self._field_transform[field] = log_transform - self.plots[field].zmin = myzmin - self.plots[field].zmax = myzmax + raise TypeError( + "Using dynamic_range requires that either zmin or zmax " + "be specified, but not both." + ) + pnh = self.plots[field].norm_handler + pnh.vmin = zmin + pnh.vmax = zmax return self @accepts_all_fields @@ -1000,7 +1022,7 @@ def set_colorbar_minorticks(self, field, state): state : bool the state indicating 'on' (True) or 'off' (False) """ - self._cbar_minorticks[field] = state + self.plots[field].colormap_handler.draw_minorticks = state return self @invalidate_plot @@ -1020,8 +1042,32 @@ def set_colorbar_label(self, field, label): ... ) """ + field = self.data_source._determine_fields(field) self._colorbar_label[field] = label return self def _get_axes_labels(self, field): return (self._xlabel, self._ylabel, self._colorbar_label[field]) + + +class BaseLinePlot(PlotContainer, abc.ABC): + + # A common ancestor to LinePlot and ProfilePlot + + @abc.abstractmethod + def _get_axrect(self): + pass + + def _get_plot_instance(self, field): + if field in self.plots: + return self.plots[field] + axrect = self._get_axrect() + + pnh = NormHandler(self.data_source, display_units=self.data_source[field].units) + finfo = self.data_source.ds._get_field_info(*field) + if not finfo.take_log: + pnh.norm_type = Normalize + plot = PlotMPL(self.figure_size, axrect, norm_handler=pnh) + self.plots[field] = plot + + return plot From a57f1e82029bb9d4c2096f347ca39bd5b5c3cb5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:09:21 +0100 Subject: [PATCH 03/54] RFC: refactor plot window to utilise NormHandler and ColorbarHandler --- yt/visualization/plot_window.py | 303 ++++++++------------------------ 1 file changed, 69 insertions(+), 234 deletions(-) diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 087ef1cbb8..709c2e1ade 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -7,14 +7,13 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +from matplotlib.colors import Normalize from more_itertools import always_iterable from mpl_toolkits.axes_grid1 import ImageGrid -from packaging.version import Version from pyparsing import ParseFatalException from unyt.exceptions import UnitConversionError from yt._maintenance.deprecation import issue_deprecation_warning -from yt.config import ytcfg from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.data_structures import YTSpatialPlotDataset from yt.funcs import fix_axis, fix_unitary, is_sequence, iter_fields, mylog, obj_length @@ -31,8 +30,9 @@ ) from yt.utilities.math_utils import ortho_find from yt.utilities.orientation import Orientation +from yt.visualization._handlers import ColorbarHandler, NormHandler -from ._commons import MPL_VERSION, _swap_axes_extents +from ._commons import _swap_axes_extents, get_default_from_config from .base_plot_types import CallbackWrapper, ImagePlotMPL from .fixed_resolution import ( FixedResolutionBuffer, @@ -42,30 +42,18 @@ from .plot_container import ( ImagePlotContainer, apply_callback, - get_log_minorticks, - get_symlog_minorticks, invalidate_data, invalidate_figure, invalidate_plot, - linear_transform, - log_transform, - symlog_transform, ) from .plot_modifications import callback_registry import sys # isort: skip -if sys.version_info < (3, 10): - # this function is deprecated in more_itertools - # because it is superseded by the standard library - from more_itertools import zip_equal +if sys.version_info >= (3, 10): + pass else: - - def zip_equal(*args): - # FUTURE: when only Python 3.10+ is supported, - # drop this conditional and call the builtin zip - # function directly where due - return zip(*args, strict=True) + from yt._maintenance.backports import zip def get_window_parameters(axis, center, width, ds): @@ -143,7 +131,7 @@ def validate_mesh_fields(data_source, fields): raise YTInvalidFieldType(invalid_fields) -class PlotWindow(ImagePlotContainer): +class PlotWindow(ImagePlotContainer, abc.ABC): r""" A plotting mechanism based around the concept of a window into a data source. It can have arbitrary fields, each of which will be @@ -256,22 +244,31 @@ def __init__( self._projection = get_mpl_transform(projection) self._transform = get_mpl_transform(transform) + self.setup_callbacks() + self._setup_plots() + for field in self.data_source._determine_fields(self.fields): finfo = self.data_source.ds._get_field_info(*field) - if finfo.take_log: - self._field_transform[field] = log_transform + pnh = self.plots[field].norm_handler + if finfo.take_log is False: + # take_log can be `None` so we explicitly compare against a boolean + pnh.norm_type = Normalize else: - self._field_transform[field] = linear_transform - - log, linthresh = self._log_config[field] - if log is not None: - self.set_log(field, log, linthresh=linthresh) - - # Access the dictionary to force the key to be created - self._units_config[field] + # do nothing, the norm handler is responsible for + # determining a viable norm, and defaults to LogNorm/SymLogNorm + pass - self.setup_callbacks() - self._setup_plots() + # override from user configuration if any + log, linthresh = get_default_from_config( + self.data_source, + field=field, + keys=["log", "linthresh"], + defaults=[None, None], + ) + if linthresh is not None: + self.set_log(field, linthresh=linthresh) + elif log is not None: + self.set_log(field, log) def __iter__(self): for ds in self.ts: @@ -316,6 +313,7 @@ def _recreate_frb(self): old_filters = self._frb._filters # Set the bounds if hasattr(self, "zlim"): + # Support OffAxisProjectionPlot and OffAxisSlicePlot bounds = self.xlim + self.ylim + self.zlim else: bounds = self.xlim + self.ylim @@ -331,50 +329,7 @@ def _recreate_frb(self): ) # At this point the frb has the valid bounds, size, aliasing, etc. - if old_fields is None: - self._frb._get_data_source_fields() - - # New frb, apply default units (if any) - for field, field_unit in self._units_config.items(): - if field_unit is None: - continue - - field_unit = Unit(field_unit, registry=self.ds.unit_registry) - is_projected = getattr(self, "projected", False) - if is_projected: - # Obtain config - path_length_units = Unit( - ytcfg.get_most_specific( - "plot", *field, "path_length_units", fallback="cm" - ), - registry=self.ds.unit_registry, - ) - units = field_unit * path_length_units - else: - units = field_unit - try: - self.frb[field].convert_to_units(units) - except UnitConversionError: - msg = ( - "Could not apply default units from configuration.\n" - "Tried converting projected field %s from %s to %s, retaining units %s:\n" - "\tgot units for field: %s" - ) - args = [ - field, - self.frb[field].units, - units, - field_unit, - units, - ] - if is_projected: - msg += "\n\tgot units for integration length: %s" - args += [path_length_units] - - msg += "\nCheck your configuration file." - - mylog.error(msg, *args) - else: + if old_fields is not None: # Restore the old fields for key, units in zip(old_fields, old_units): self._frb[key] @@ -504,7 +459,6 @@ def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None): The name of the field that is to be changed. new_unit : string or Unit object - The name of the new unit. equivalency : string, optional If set, the equivalency to use to convert the current units to @@ -515,9 +469,11 @@ def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None): Keyword arguments to be passed to the equivalency. Only used if ``equivalency`` is set. """ - for f, u in zip_equal(iter_fields(field), always_iterable(new_unit)): + for f, u in zip(iter_fields(field), always_iterable(new_unit), strict=True): self.frb.set_unit(f, u, equivalency, equivalency_kwargs) self._equivalencies[f] = (equivalency, equivalency_kwargs) + pnh = self.plots[f].norm_handler + pnh.display_units = u return self @invalidate_plot @@ -683,6 +639,7 @@ def _set_window(self, bounds): self.xlim = tuple(bounds[0:2]) self.ylim = tuple(bounds[2:4]) if len(bounds) == 6: + # Support OffAxisProjectionPlot and OffAxisSlicePlot self.zlim = tuple(bounds[4:6]) mylog.info("xlim = %f %f", self.xlim[0], self.xlim[1]) mylog.info("ylim = %f %f", self.ylim[0], self.ylim[1]) @@ -1113,72 +1070,28 @@ def _setup_plots(self): extent = [*extentx, *extenty] + image = self.frb[f] + font_size = self._font_properties.get_size() + if f in self.plots.keys(): - zlim = (self.plots[f].zmin, self.plots[f].zmax) + pnh = self.plots[f].norm_handler + cbh = self.plots[f].colorbar_handler else: - zlim = (None, None) - - image = self.frb[f] - if self._field_transform[f] == log_transform: - msg = None - use_symlog = False - if zlim != (None, None): - pass - elif np.nanmax(image) == np.nanmin(image): - msg = f"Plotting {f}: All values = {np.nanmax(image)}" - elif np.nanmax(image) <= 0: - msg = ( - f"Plotting {f}: All negative values. Max = {np.nanmax(image)}." - ) - use_symlog = True - elif not np.any(np.isfinite(image)): - msg = f"Plotting {f}: All values = NaN." - elif np.nanmax(image) > 0.0 and np.nanmin(image) <= 0: - msg = ( - f"Plotting {f}: Both positive and negative values. " - f"Min = {np.nanmin(image)}, Max = {np.nanmax(image)}." + pnh, cbh = self._get_default_handlers( + field=f, default_display_units=image.units + ) + if pnh.display_units != image.units: + equivalency, equivalency_kwargs = self._equivalencies[f] + image.convert_to_units( + pnh.display_units, equivalency, **equivalency_kwargs ) - use_symlog = True - elif ( - (Version("3.3") <= MPL_VERSION < Version("3.5")) - and np.nanmax(image) > 0.0 - and np.nanmin(image) == 0 - ): - # normally, a LogNorm scaling would still be OK here because - # LogNorm will mask 0 values when calculating vmin. But - # due to a bug in matplotlib's imshow, if the data range - # spans many orders of magnitude while containing zero points - # vmin can get rescaled to 0, resulting in an error when the image - # gets drawn. So here we switch to symlog to avoid that until - # a fix is in -- see PR #3161 and linked issue. - cutoff_sigdigs = 15 - if ( - np.log10(np.nanmax(image[np.isfinite(image)])) - - np.log10(np.nanmin(image[image > 0])) - > cutoff_sigdigs - ): - msg = f"Plotting {f}: Wide range and zeros." - use_symlog = True - if msg is not None: - mylog.warning(msg) - if use_symlog: - mylog.warning("Switching to symlog colorbar scaling.") - self._field_transform[f] = symlog_transform - self._field_transform[f].func = None - else: - mylog.warning("Switching to linear colorbar scaling.") - self._field_transform[f] = linear_transform - - font_size = self._font_properties.get_size() fig = None axes = None cax = None - draw_colorbar = True draw_axes = True draw_frame = draw_axes if f in self.plots: - draw_colorbar = self.plots[f]._draw_colorbar draw_axes = self.plots[f]._draw_axes draw_frame = self.plots[f]._draw_frame if self.plots[f].figure is not None: @@ -1210,11 +1123,7 @@ def _setup_plots(self): self.plots[f] = WindowPlotMPL( ia, - self._field_transform[f].name, - self._field_transform[f].func, - self._colormap_config[f], extent, - zlim, self.figure_size, font_size, aspect, @@ -1223,6 +1132,8 @@ def _setup_plots(self): cax, self._projection, self._transform, + norm_handler=pnh, + colorbar_handler=cbh, ) axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) @@ -1280,10 +1191,6 @@ def _setup_plots(self): self.plots[f].axes.set_xlabel(labels[0]) self.plots[f].axes.set_ylabel(labels[1]) - color = self._background_color[f] - - self.plots[f].axes.set_facecolor(color) - # Determine the units of the data units = Unit(self.frb[f].units, registry=self.ds.unit_registry) units = units.latex_representation() @@ -1313,63 +1220,9 @@ def _setup_plots(self): else: self.plots[f].axes.minorticks_off() - # colorbar minorticks - if f not in self._cbar_minorticks: - self._cbar_minorticks[f] = True - - if self._cbar_minorticks[f]: - vmin = np.float64(self.plots[f].cb.norm.vmin) - vmax = np.float64(self.plots[f].cb.norm.vmax) - - if self._field_transform[f] == linear_transform: - self.plots[f].cax.minorticks_on() - - elif self._field_transform[f] == symlog_transform: - if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"): - # no known working method to draw symlog minor ticks - # see https://github.com/yt-project/yt/issues/3535 - pass - else: - flinthresh = 10 ** np.floor( - np.log10(self.plots[f].cb.norm.linthresh) - ) - absmax = np.abs((vmin, vmax)).max() - if (absmax - flinthresh) / absmax < 0.1: - flinthresh /= 10 - mticks = get_symlog_minorticks(flinthresh, vmin, vmax) - if MPL_VERSION < Version("3.5.0b"): - # https://github.com/matplotlib/matplotlib/issues/21258 - mticks = self.plots[f].image.norm(mticks) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - - elif self._field_transform[f] == log_transform: - if MPL_VERSION >= Version("3.0.0"): - self.plots[f].cax.minorticks_on() - self.plots[f].cax.xaxis.set_visible(False) - else: - mticks = self.plots[f].image.norm( - get_log_minorticks(vmin, vmax) - ) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - - else: - mylog.error( - "Unable to draw cbar minorticks for field " - "%s with transform %s ", - f, - self._field_transform[f], - ) - self._cbar_minorticks[f] = False - - if not self._cbar_minorticks[f]: - self.plots[f].cax.minorticks_off() - if not draw_axes: self.plots[f]._toggle_axes(draw_axes, draw_frame) - if not draw_colorbar: - self.plots[f]._toggle_colorbar(draw_colorbar) - self._set_font_properties() self.run_callbacks() @@ -2693,11 +2546,7 @@ class WindowPlotMPL(ImagePlotMPL): def __init__( self, data, - cbname, - cblinthresh, - cmap, extent, - zlim, figure_size, fontsize, aspect, @@ -2706,55 +2555,41 @@ def __init__( cax, mpl_proj, mpl_transform, + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, ): - from matplotlib.ticker import ScalarFormatter - - self._draw_colorbar = True - self._draw_axes = True - self._draw_frame = True - self._fontsize = fontsize - self._figure_size = figure_size self._projection = mpl_proj self._transform = mpl_transform + self._setup_layout_constraints(figure_size, fontsize) + self._draw_frame = True + self._aspect = ((extent[1] - extent[0]) / (extent[3] - extent[2])).in_cgs() + self._unit_aspect = aspect + # Compute layout - fontscale = float(fontsize) / 18.0 + self._figure_size = figure_size + self._draw_axes = True + fontscale = float(fontsize) / self.__class__._default_font_size if fontscale < 1.0: fontscale = np.sqrt(fontscale) if is_sequence(figure_size): - fsize = figure_size[0] + self._cb_size = 0.0375 * figure_size[0] else: - fsize = figure_size - self._cb_size = 0.0375 * fsize + self._cb_size = 0.0375 * figure_size self._ax_text_size = [1.2 * fontscale, 0.9 * fontscale] self._top_buff_size = 0.30 * fontscale - self._aspect = ((extent[1] - extent[0]) / (extent[3] - extent[2])).in_cgs() - self._unit_aspect = aspect - - size, axrect, caxrect = self._get_best_layout() - - super().__init__(size, axrect, caxrect, zlim, figure, axes, cax) - self._init_image(data, cbname, cblinthresh, cmap, extent, aspect) + super().__init__( + figure=figure, + axes=axes, + cax=cax, + norm_handler=norm_handler, + colorbar_handler=colorbar_handler, + ) - # In matplotlib 2.1 and newer we'll be able to do this using - # self.image.axes.ticklabel_format - # See https://github.com/matplotlib/matplotlib/pull/6337 - formatter = ScalarFormatter(useMathText=True) - formatter.set_scientific(True) - formatter.set_powerlimits((-2, 3)) - self.image.axes.xaxis.set_major_formatter(formatter) - self.image.axes.yaxis.set_major_formatter(formatter) - if cbname == "linear": - self.cb.formatter.set_scientific(True) - try: - self.cb.formatter.set_useMathText(True) - except AttributeError: - # this is only available in mpl > 2.1 - pass - self.cb.formatter.set_powerlimits((-2, 3)) - self.cb.update_ticks() + self._init_image(data, extent, aspect) def _create_axes(self, axrect): self.axes = self.figure.add_axes(axrect, projection=self._projection) From 968d9d5dc7fc8d26a98f99e99c0809884ac1ea5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:09:44 +0100 Subject: [PATCH 04/54] RFC: refactor line plot to utilise NormHandler and ColorbarHandler --- yt/visualization/line_plot.py | 135 ++++++++++++++++------------------ 1 file changed, 62 insertions(+), 73 deletions(-) diff --git a/yt/visualization/line_plot.py b/yt/visualization/line_plot.py index 5df15c5d8d..2d44176138 100644 --- a/yt/visualization/line_plot.py +++ b/yt/visualization/line_plot.py @@ -1,17 +1,16 @@ from collections import defaultdict +from typing import Optional import numpy as np +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from yt.funcs import is_sequence, mylog from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTArray -from yt.visualization.base_plot_types import PlotMPL from yt.visualization.plot_container import ( - PlotContainer, + BaseLinePlot, PlotDictionary, invalidate_plot, - linear_transform, - log_transform, ) @@ -87,10 +86,7 @@ def _sanitize_dimensions(self, item): ).dimensions if dimensions not in self.known_dimensions: self.known_dimensions[dimensions] = item - ret_item = item - else: - ret_item = self.known_dimensions[dimensions] - return ret_item + return self.known_dimensions[dimensions] def __getitem__(self, item): ret_item = self._sanitize_dimensions(item) @@ -105,7 +101,7 @@ def __contains__(self, item): return super().__contains__(ret_item) -class LinePlot(PlotContainer): +class LinePlot(BaseLinePlot): r""" A class for constructing line plots @@ -152,8 +148,12 @@ class LinePlot(PlotContainer): >>> plot.save() """ + _plot_dict_type = LinePlotDictionary _plot_type = "line_plot" + _default_figure_size = (5.0, 5.0) + _default_font_size = 14.0 + def __init__( self, ds, @@ -161,8 +161,8 @@ def __init__( start_point, end_point, npoints, - figure_size=5, - fontsize=14, + figure_size=None, + fontsize: Optional[float] = None, field_labels=None, ): """ @@ -175,25 +175,18 @@ def __init__( @classmethod def _initialize_instance( - cls, obj, ds, fields, figure_size=5, fontsize=14, field_labels=None + cls, obj, ds, fields, figure_size, fontsize, field_labels=None ): obj._x_unit = None - obj._y_units = {} obj._titles = {} data_source = ds.all_data() obj.fields = data_source._determine_fields(fields) - obj.plots = LinePlotDictionary(data_source) obj.include_legend = defaultdict(bool) - super(LinePlot, obj).__init__(data_source, figure_size, fontsize) - for f in obj.fields: - finfo = obj.data_source.ds._get_field_info(*f) - if finfo.take_log: - obj._field_transform[f] = log_transform - else: - obj._field_transform[f] = linear_transform - + super(LinePlot, obj).__init__( + data_source, figure_size=figure_size, fontsize=fontsize + ) if field_labels is None: obj.field_labels = {} else: @@ -202,9 +195,35 @@ def _initialize_instance( if f not in obj.field_labels: obj.field_labels[f] = f[1] + def _get_axrect(self): + fontscale = self._font_properties._size / self.__class__._default_font_size + top_buff_size = 0.35 * fontscale + + x_axis_size = 1.35 * fontscale + y_axis_size = 0.7 * fontscale + right_buff_size = 0.2 * fontscale + + if is_sequence(self.figure_size): + figure_size = self.figure_size + else: + figure_size = (self.figure_size, self.figure_size) + + xbins = np.array([x_axis_size, figure_size[0], right_buff_size]) + ybins = np.array([y_axis_size, figure_size[1], top_buff_size]) + + x_frac_widths = xbins / xbins.sum() + y_frac_widths = ybins / ybins.sum() + + return ( + x_frac_widths[0], + y_frac_widths[0], + x_frac_widths[1], + y_frac_widths[1], + ) + @classmethod def from_lines( - cls, ds, fields, lines, figure_size=5, font_size=14, field_labels=None + cls, ds, fields, lines, figure_size=None, font_size=None, field_labels=None ): """ A class method for constructing a line plot from multiple sampling lines @@ -252,41 +271,6 @@ def from_lines( obj._setup_plots() return obj - def _get_plot_instance(self, field): - fontscale = self._font_properties._size / 14.0 - top_buff_size = 0.35 * fontscale - - x_axis_size = 1.35 * fontscale - y_axis_size = 0.7 * fontscale - right_buff_size = 0.2 * fontscale - - if is_sequence(self.figure_size): - figure_size = self.figure_size - else: - figure_size = (self.figure_size, self.figure_size) - - xbins = np.array([x_axis_size, figure_size[0], right_buff_size]) - ybins = np.array([y_axis_size, figure_size[1], top_buff_size]) - - size = [xbins.sum(), ybins.sum()] - - x_frac_widths = xbins / size[0] - y_frac_widths = ybins / size[1] - - axrect = ( - x_frac_widths[0], - y_frac_widths[0], - x_frac_widths[1], - y_frac_widths[1], - ) - - try: - plot = self.plots[field] - except KeyError: - plot = PlotMPL(self.figure_size, axrect, None, None) - self.plots[field] = plot - return plot - def _setup_plots(self): if self._plot_valid: return @@ -315,13 +299,10 @@ def _setup_plots(self): else: unit_x = self._x_unit - if field in self._y_units: - unit_y = self._y_units[field] - else: - unit_y = y.units + unit_y = plot.norm_handler.display_units - x = x.to(unit_x) - y = y.to(unit_y) + x.convert_to_units(unit_x) + y.convert_to_units(unit_y) # determine legend label str_seq = [] @@ -334,11 +315,18 @@ def _setup_plots(self): plot.axes.plot(x, y, label=legend_label) # apply log transforms if requested - if self._field_transform[field] != linear_transform: - if (y <= 0).any(): - plot.axes.set_yscale("symlog") - else: - plot.axes.set_yscale("log") + norm = plot.norm_handler.get_norm(data=y) + y_norm_type = type(norm) + if y_norm_type is Normalize: + plot.axes.set_yscale("linear") + elif y_norm_type is LogNorm: + plot.axes.set_yscale("log") + elif y_norm_type is SymLogNorm: + plot.axes.set_yscale("symlog") + else: + raise NotImplementedError( + f"LinePlot doesn't support y norm with type {type(norm)}" + ) # set font properties plot._set_font_properties(self._font_properties, None) @@ -409,17 +397,18 @@ def set_x_unit(self, unit_name): self._x_unit = unit_name @invalidate_plot - def set_unit(self, field, unit_name): + def set_unit(self, field, new_unit): """Set the unit used to plot the field Parameters ---------- field: str or field tuple The name of the field to set the units for - unit_name: str - The name of the unit to use for this field + new_unit: string or Unit object """ - self._y_units[self.data_source._determine_fields(field)[0]] = unit_name + field = self.data_source._determine_fields(field)[0] + pnh = self.plots[field].norm_handler + pnh.display_units = new_unit @invalidate_plot def annotate_title(self, field, title): From ee4a49f49dc998b198436413f44fba5ea89f5041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:10:02 +0100 Subject: [PATCH 05/54] RFC: refactor profile plots to utilise NormHandler and ColorbarHandler --- yt/visualization/profile_plotter.py | 308 ++++++++++------------------ 1 file changed, 113 insertions(+), 195 deletions(-) diff --git a/yt/visualization/profile_plotter.py b/yt/visualization/profile_plotter.py index b58086e865..573cc36611 100644 --- a/yt/visualization/profile_plotter.py +++ b/yt/visualization/profile_plotter.py @@ -1,38 +1,31 @@ import base64 import builtins import os -from collections import OrderedDict from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple, Union import matplotlib import numpy as np -from matplotlib.font_manager import FontProperties from more_itertools.more import always_iterable, unzip -from packaging.version import Version from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys from yt.data_objects.static_output import Dataset from yt.frontends.ytdata.data_structures import YTProfileDataset -from yt.funcs import is_sequence, iter_fields, matplotlib_style_context +from yt.funcs import iter_fields, matplotlib_style_context from yt.utilities.exceptions import YTNotInsideNotebook -from yt.utilities.logger import ytLogger as mylog +from yt.visualization._handlers import ColorbarHandler, NormHandler +from yt.visualization.base_plot_types import PlotMPL from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer -from ._commons import DEFAULT_FONT_PROPERTIES, validate_image_name -from .base_plot_types import ImagePlotMPL, PlotMPL +from ._commons import validate_image_name +from .base_plot_types import ImagePlotMPL from .plot_container import ( + BaseLinePlot, ImagePlotContainer, - PlotContainer, - get_log_minorticks, invalidate_plot, - linear_transform, - log_transform, validate_plot, ) -MPL_VERSION = Version(matplotlib.__version__) - def invalidate_profile(f): @wraps(f) @@ -44,42 +37,6 @@ def newfunc(*args, **kwargs): return newfunc -class PlotContainerDict(OrderedDict): - def __missing__(self, key): - plot = PlotMPL((10, 8), [0.1, 0.1, 0.8, 0.8], None, None) - self[key] = plot - return self[key] - - -class FigureContainer(OrderedDict): - def __init__(self, plots): - self.plots = plots - super().__init__() - - def __missing__(self, key): - self[key] = self.plots[key].figure - return self[key] - - def __iter__(self): - return iter(self.plots) - - -class AxesContainer(OrderedDict): - def __init__(self, plots): - self.plots = plots - self.ylim = {} - self.xlim = (None, None) - super().__init__() - - def __missing__(self, key): - self[key] = self.plots[key].axes - return self[key] - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self.ylim[key] = (None, None) - - def sanitize_label(labels, nprofiles): labels = list(always_iterable(labels)) or [None] @@ -116,7 +73,7 @@ def data_object_or_all_data(data_source): return data_source -class ProfilePlot(PlotContainer): +class ProfilePlot(BaseLinePlot): r""" Create a 1d profile plot from a data source or from a list of profile objects. @@ -219,6 +176,8 @@ class ProfilePlot(PlotContainer): Use set_line_property to change line properties of one or all profiles. """ + _default_figure_size = (10.0, 8.0) + _default_font_size = 18.0 x_log = None y_log = None @@ -240,7 +199,6 @@ def __init__( x_log=True, y_log=True, ): - data_source = data_object_or_all_data(data_source) y_fields = list(iter_fields(y_fields)) logs = {x_field: bool(x_log)} @@ -268,7 +226,42 @@ def __init__( if not isinstance(plot_spec, list): plot_spec = [plot_spec.copy() for p in profiles] - ProfilePlot._initialize_instance(self, profiles, label, plot_spec, y_log) + ProfilePlot._initialize_instance( + self, data_source, profiles, label, plot_spec, y_log + ) + + @classmethod + def _initialize_instance( + cls, + obj, + data_source, + profiles, + labels, + plot_specs, + y_log, + ): + obj._plot_title = {} + obj._plot_text = {} + obj._text_xpos = {} + obj._text_ypos = {} + obj._text_kwargs = {} + + super(ProfilePlot, obj).__init__(data_source) + obj.profiles = list(always_iterable(profiles)) + obj.x_log = None + obj.y_log = sanitize_field_tuple_keys(y_log, data_source) or {} + obj.y_title = {} + obj.x_title = None + obj.label = sanitize_label(labels, len(obj.profiles)) + if plot_specs is None: + plot_specs = [dict() for p in obj.profiles] + obj.plot_spec = plot_specs + obj._xlim = (None, None) + obj._setup_plots() + return obj + + def _get_axrect(self): + return (0.1, 0.1, 0.8, 0.8) @validate_plot def save( @@ -293,13 +286,14 @@ def save( if not self._plot_valid: self._setup_plots() - # Mypy is hardly convinced that we have a `plots` and a `profile` attr + # Mypy is hardly convinced that we have a `profiles` attribute # at this stage, so we're lasily going to deactivate it locally - unique = set(self.plots.values()) # type: ignore - if len(unique) < len(self.plots): # type: ignore - iters = zip(range(len(unique)), sorted(unique)) + unique = set(self.plots.values()) + iters: Iterable[Tuple[Union[int, Tuple[str, str]], PlotMPL]] + if len(unique) < len(self.plots): + iters = enumerate(sorted(unique)) else: - iters = self.plots.items() # type: ignore + iters = self.plots.items() if name is None: if len(self.profiles) == 1: # type: ignore @@ -316,11 +310,10 @@ def save( names = [] for uid, plot in iters: - if isinstance(uid, tuple): # type: ignore + if isinstance(uid, tuple): uid = uid[1] # type: ignore uid_name = f"{prefix}_1d-Profile_{xfn}_{uid}{suffix}" names.append(uid_name) - mylog.info("Saving %s", uid_name) with matplotlib_style_context(): plot.save(uid_name, mpl_kwargs=mpl_kwargs) return names @@ -375,10 +368,10 @@ def _repr_html_(self): def _setup_plots(self): if self._plot_valid: return - for f in self.axes: - self.axes[f].cla() + for f, p in self.plots.items(): + p.axes.cla() if f in self._plot_text: - self.plots[f].axes.text( + p.axes.text( self._text_xpos[f], self._text_ypos[f], self._plot_text[f], @@ -389,7 +382,8 @@ def _setup_plots(self): for i, profile in enumerate(self.profiles): for field, field_data in profile.items(): - self.axes[field].plot( + plot = self._get_plot_instance(field) + plot.axes.plot( np.array(profile.x), np.array(field_data), label=self.label[i], @@ -398,7 +392,7 @@ def _setup_plots(self): for profile in self.profiles: for fname in profile.keys(): - axes = self.axes[fname] + axes = self.plots[fname].axes xscale, yscale = self._get_field_log(fname, profile) xtitle, ytitle = self._get_field_title(fname, profile) @@ -408,8 +402,10 @@ def _setup_plots(self): axes.set_ylabel(ytitle) axes.set_xlabel(xtitle) - axes.set_ylim(*self.axes.ylim[fname]) - axes.set_xlim(*self.axes.xlim) + pnh = self.plots[fname].norm_handler + + axes.set_ylim(pnh.vmin, pnh.vmax) + axes.set_xlim(*self._xlim) if fname in self._plot_title: axes.set_title(self._plot_title[fname]) @@ -419,31 +415,6 @@ def _setup_plots(self): self._set_font_properties() self._plot_valid = True - @classmethod - def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log): - obj._plot_title = {} - obj._plot_text = {} - obj._text_xpos = {} - obj._text_ypos = {} - obj._text_kwargs = {} - - obj._font_properties = FontProperties(**DEFAULT_FONT_PROPERTIES) - obj._font_color = None - obj.profiles = list(always_iterable(profiles)) - obj.x_log = None - obj.y_log = sanitize_field_tuple_keys(y_log, obj.profiles[0].data_source) or {} - obj.y_title = {} - obj.x_title = None - obj.label = sanitize_label(labels, len(obj.profiles)) - if plot_specs is None: - plot_specs = [dict() for p in obj.profiles] - obj.plot_spec = plot_specs - obj.plots = PlotContainerDict() - obj.figures = FigureContainer(obj.plots) - obj.axes = AxesContainer(obj.plots) - obj._setup_plots() - return obj - @classmethod def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None): r""" @@ -497,7 +468,10 @@ def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None): "Profiles list and plot_specs list must be the same size." ) obj = cls.__new__(cls) - return cls._initialize_instance(obj, profiles, labels, plot_specs, y_log) + profiles = list(always_iterable(profiles)) + return cls._initialize_instance( + obj, profiles[0].data_source, profiles, labels, plot_specs, y_log + ) @invalidate_plot def set_line_property(self, property, value, index=None): @@ -644,7 +618,7 @@ def set_xlim(self, xmin=None, xmax=None): >>> pp.save() """ - self.axes.xlim = (xmin, xmax) + self._xlim = (xmin, xmax) for i, p in enumerate(self.profiles): if xmin is None: xmi = p.x_bins.min() @@ -709,12 +683,14 @@ def set_ylim(self, field, ymin=None, ymax=None): >>> pp.save() """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: field = profile.field_map[field] - self.axes.ylim[field] = (ymin, ymax) + pnh = self.plots[field].norm_handler + pnh.vmin = ymin + pnh.vmax = ymax # Continue on to the next profile. break return self @@ -793,7 +769,7 @@ def annotate_title(self, title, field="all"): ... ) """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: @@ -845,7 +821,7 @@ def annotate_text(self, xpos=0.0, ypos=0.0, text=None, field="all", **text_kwarg >>> plot.save() """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: @@ -1059,10 +1035,6 @@ def _get_field_log(self, field_z, profile): scales = {True: "log", False: "linear"} return scales[x_log], scales[y_log], scales[z_log] - def _recreate_frb(self): - # needed for API compatibility with PlotWindow - pass - @property def profile(self): if not self._profile_valid: @@ -1080,45 +1052,26 @@ def _setup_plots(self): fig = None axes = None cax = None - draw_colorbar = True draw_axes = True - zlim = (None, None) xlim = self._xlim ylim = self._ylim if f in self.plots: - draw_colorbar = self.plots[f]._draw_colorbar + pnh = self.plots[f].norm_handler + cbh = self.plots[f].colorbar_handler draw_axes = self.plots[f]._draw_axes - zlim = (self.plots[f].zmin, self.plots[f].zmax) if self.plots[f].figure is not None: fig = self.plots[f].figure axes = self.plots[f].axes cax = self.plots[f].cax + else: + pnh, cbh = self._get_default_handlers( + field=f, default_display_units=self.profile[f].units + ) x_scale, y_scale, z_scale = self._get_field_log(f, self.profile) x_title, y_title, z_title = self._get_field_title(f, self.profile) - if zlim == (None, None): - if z_scale == "log": - positive_values = data[data > 0.0] - if len(positive_values) == 0: - mylog.warning( - "Profiled field %s has no positive values. Max = %f.", - f, - np.nanmax(data), - ) - mylog.warning("Switching to linear colorbar scaling.") - zmin = np.nanmin(data) - z_scale = "linear" - self._field_transform[f] = linear_transform - else: - zmin = positive_values.min() - self._field_transform[f] = log_transform - else: - zmin = np.nanmin(data) - self._field_transform[f] = linear_transform - zlim = [zmin, np.nanmax(data)] - font_size = self._font_properties.get_size() f = self.profile.data_source._determine_fields(f)[0] @@ -1126,9 +1079,7 @@ def _setup_plots(self): # override the colorbar here. splat_color = getattr(self, "splat_color", None) if splat_color is not None: - cmap = matplotlib.colors.ListedColormap(splat_color, "dummy") - else: - cmap = self._colormap_config[f] + cbh.cmap = matplotlib.colors.ListedColormap(splat_color, "dummy") masked_data = data.copy() masked_data[~self.profile.used] = np.nan @@ -1138,19 +1089,18 @@ def _setup_plots(self): masked_data, x_scale, y_scale, - z_scale, - cmap, - zlim, self.figure_size, font_size, fig, axes, cax, shading=self._shading, + norm_handler=pnh, + colorbar_handler=cbh, ) self.plots[f]._toggle_axes(draw_axes) - self.plots[f]._toggle_colorbar(draw_colorbar) + self.plots[f]._toggle_colorbar(cbh.draw_cbar) self.plots[f].axes.xaxis.set_label_text(x_title) self.plots[f].axes.yaxis.set_label_text(y_title) @@ -1159,10 +1109,6 @@ def _setup_plots(self): self.plots[f].axes.set_xlim(xlim) self.plots[f].axes.set_ylim(ylim) - color = self._background_color[f] - - self.plots[f].axes.set_facecolor(color) - if f in self._plot_text: self.plots[f].axes.text( self._text_xpos[f], @@ -1183,25 +1129,6 @@ def _setup_plots(self): else: self.plots[f].axes.minorticks_off() - # colorbar minorticks - if f not in self._cbar_minorticks: - self._cbar_minorticks[f] = True - if self._cbar_minorticks[f]: - if self._field_transform[f] == linear_transform: - self.plots[f].cax.minorticks_on() - elif MPL_VERSION < Version("3.0.0"): - # before matplotlib 3 log-scaled colorbars internally used - # a linear scale going from zero to one and did not draw - # minor ticks. Since we want minor ticks, calculate - # where the minor ticks should go in this linear scale - # and add them manually. - vmin = np.float64(self.plots[f].cb.norm.vmin) - vmax = np.float64(self.plots[f].cb.norm.vmax) - mticks = self.plots[f].image.norm(get_log_minorticks(vmin, vmax)) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - else: - self.plots[f].cax.minorticks_off() - self._set_font_properties() # if this is a particle plot with one color only, hide the cbar here @@ -1425,7 +1352,7 @@ def set_log(self, field, log): self.y_log = log self._profile_valid = False elif field in p.field_data: - self.z_log[field] = log + super().set_log(field, log) else: raise KeyError(f"Field {field} not in phase plot!") return self @@ -1449,7 +1376,7 @@ def set_unit(self, field, unit): self.profile.set_y_unit(unit) elif fd in self.profile.field_data.keys(): self.profile.set_field_unit(field, unit) - self.plots[field].zmin, self.plots[field].zmax = (None, None) + self.plots[field].norm_handler.display_units = unit else: raise KeyError(f"Field {field} not in phase plot!") return self @@ -1580,50 +1507,47 @@ def __init__( data, x_scale, y_scale, - z_scale, - cmap, - zlim, figure_size, fontsize, figure, axes, cax, shading="nearest", + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, ): self._initfinished = False - self._draw_colorbar = True - self._draw_axes = True - self._figure_size = figure_size self._shading = shading - # Compute layout - fontscale = float(fontsize) / 18.0 - if fontscale < 1.0: - fontscale = np.sqrt(fontscale) - - if is_sequence(figure_size): - self._cb_size = 0.0375 * figure_size[0] - else: - self._cb_size = 0.0375 * figure_size - self._ax_text_size = [1.1 * fontscale, 0.9 * fontscale] - self._top_buff_size = 0.30 * fontscale - self._aspect = 1.0 - - size, axrect, caxrect = self._get_best_layout() - - super().__init__(size, axrect, caxrect, zlim, figure, axes, cax) + self._setup_layout_constraints(figure_size, fontsize) + + # this line is added purely to prevent exact image comparison tests + # to fail, but eventually we should embrace the change and + # use similar values for PhasePlotMPL and WindowPlotMPL + self._ax_text_size[0] *= 1.1 / 1.2 # TODO: remove this + + super().__init__( + figure=figure, + axes=axes, + cax=cax, + norm_handler=norm_handler, + colorbar_handler=colorbar_handler, + ) - self._init_image(x_data, y_data, data, x_scale, y_scale, z_scale, zlim, cmap) + self._init_image(x_data, y_data, data, x_scale, y_scale) self._initfinished = True def _init_image( - self, x_data, y_data, image_data, x_scale, y_scale, z_scale, zlim, cmap + self, + x_data, + y_data, + image_data, + x_scale, + y_scale, ): """Store output of imshow in image variable""" - if z_scale == "log": - norm = matplotlib.colors.LogNorm(zlim[0], zlim[1]) - elif z_scale == "linear": - norm = matplotlib.colors.Normalize(zlim[0], zlim[1]) + norm = self.norm_handler.get_norm(image_data) self.image = None self.cb = None @@ -1632,16 +1556,10 @@ def _init_image( np.array(y_data), np.array(image_data.T), norm=norm, - cmap=cmap, + cmap=self.colorbar_handler.cmap, shading=self._shading, ) + self._set_axes(norm) self.axes.set_xscale(x_scale) self.axes.set_yscale(y_scale) - self.cb = self.figure.colorbar(self.image, self.cax) - if z_scale == "linear": - self.cb.formatter.set_scientific(True) - self.cb.formatter.set_powerlimits((-2, 3)) - self.cb.update_ticks() - - self.cax.tick_params(which="both", axis="y", direction="in") From cb519db678e6a062714d06cbd0ec960f226e3dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:11:08 +0100 Subject: [PATCH 06/54] TST: adapt tests to new NormHandler and ColorbarHandler mechanisms --- yt/visualization/tests/test_particle_plot.py | 17 +++++--- yt/visualization/tests/test_plotwindow.py | 45 ++++++++++++++------ 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/yt/visualization/tests/test_particle_plot.py b/yt/visualization/tests/test_particle_plot.py index b6cf68d64d..d5e060cb52 100644 --- a/yt/visualization/tests/test_particle_plot.py +++ b/yt/visualization/tests/test_particle_plot.py @@ -150,12 +150,17 @@ def formed_star(pfilter, data): ds.add_particle_filter("formed_star") for ax in "xyz": attr_name = "set_log" - for args in PROJ_ATTR_ARGS[attr_name]: - test = PlotWindowAttributeTest( - ds, plot_field, ax, attr_name, args, decimals, "ParticleProjectionPlot" - ) - test_particle_projection_filter.__name__ = test.description - yield test + test = PlotWindowAttributeTest( + ds, + plot_field, + ax, + attr_name, + ((plot_field, False), {}), + decimals, + "ParticleProjectionPlot", + ) + test_particle_projection_filter.__name__ = test.description + yield test @requires_ds(g30, big_data=True) diff --git a/yt/visualization/tests/test_plotwindow.py b/yt/visualization/tests/test_plotwindow.py index 9094eb9bd2..bf19692c86 100644 --- a/yt/visualization/tests/test_plotwindow.py +++ b/yt/visualization/tests/test_plotwindow.py @@ -5,10 +5,13 @@ from collections import OrderedDict import numpy as np +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from nose.tools import assert_true +from unyt import unyt_array from yt.loaders import load_uniform_grid from yt.testing import ( + assert_allclose_units, assert_array_almost_equal, assert_array_equal, assert_equal, @@ -423,13 +426,13 @@ def setUp(self): fields_to_plot = fields + [("index", "radius")] if self.ds is None: self.ds = fake_random_ds(16, fields=fields, units=units) - self.slc = ProjectionPlot(self.ds, 0, fields_to_plot) + self.proj = ProjectionPlot(self.ds, 0, fields_to_plot) def tearDown(self): from yt.config import ytcfg del self.ds - del self.slc + del self.proj for key in self.newConfig.keys(): ytcfg.remove(*key) for key, val in self.oldConfig.items(): @@ -438,21 +441,37 @@ def tearDown(self): def test_units(self): from unyt import Unit - assert_equal(self.slc.frb["gas", "density"].units, Unit("mile*lb/yd**3")) - assert_equal(self.slc.frb["gas", "temperature"].units, Unit("cm*K")) - assert_equal(self.slc.frb["gas", "pressure"].units, Unit("dyn/cm")) + assert_equal(self.proj.frb["gas", "density"].units, Unit("mile*lb/yd**3")) + assert_equal(self.proj.frb["gas", "temperature"].units, Unit("cm*K")) + assert_equal(self.proj.frb["gas", "pressure"].units, Unit("dyn/cm")) def test_scale(self): - assert_equal(self.slc._field_transform["gas", "density"].name, "linear") - assert_equal(self.slc._field_transform["gas", "temperature"].name, "symlog") - assert_equal(self.slc._field_transform["gas", "temperature"].func, 100) - assert_equal(self.slc._field_transform["gas", "pressure"].name, "log10") - assert_equal(self.slc._field_transform["index", "radius"].name, "log10") + + assert_equal( + self.proj.plots["gas", "density"].norm_handler.norm_type, Normalize + ) + assert_equal( + self.proj.plots["gas", "temperature"].norm_handler.norm_type, SymLogNorm + ) + assert_allclose_units( + self.proj.plots["gas", "temperature"].norm_handler.linthresh, + unyt_array(100, "K*cm"), + ) + assert_equal(self.proj.plots["gas", "pressure"].norm_handler.norm_type, LogNorm) + assert_equal( + self.proj.plots["index", "radius"].norm_handler.norm_type, SymLogNorm + ) def test_cmap(self): - assert_equal(self.slc._colormap_config["gas", "density"], "plasma") - assert_equal(self.slc._colormap_config["gas", "temperature"], "hot") - assert_equal(self.slc._colormap_config["gas", "pressure"], "viridis") + assert_equal( + self.proj.plots["gas", "density"].colorbar_handler.cmap.name, "plasma" + ) + assert_equal( + self.proj.plots["gas", "temperature"].colorbar_handler.cmap.name, "hot" + ) + assert_equal( + self.proj.plots["gas", "pressure"].colorbar_handler.cmap.name, "viridis" + ) def test_on_off_compare(): From 26a5a4e973ae83a5c3c933add79c9327befa9363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:12:15 +0100 Subject: [PATCH 07/54] RFC: adapt eps_writer to new norm handling paradigm --- yt/visualization/eps_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yt/visualization/eps_writer.py b/yt/visualization/eps_writer.py index e3aa1ffc1e..88109f5675 100644 --- a/yt/visualization/eps_writer.py +++ b/yt/visualization/eps_writer.py @@ -861,7 +861,7 @@ def colorbar_yt(self, plot, field=None, cb_labels=None, **kwargs): if field is not None: self.field = plot.data_source._determine_fields(field)[0] if isinstance(plot, (PlotWindow, PhasePlot)): - _cmap = plot._colormap_config[self.field] + _cmap = plot[self.field].colorbar_handler.cmap else: if plot.cmap is not None: _cmap = plot.cmap.name From 07326b7ce5b59bc86da20d6a5e82c56022fee439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:13:05 +0100 Subject: [PATCH 08/54] DEPR: deprecate a user-exposed orphan function that was intended as internal --- yt/funcs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/yt/funcs.py b/yt/funcs.py index 0803762680..17a6d05653 100644 --- a/yt/funcs.py +++ b/yt/funcs.py @@ -1258,6 +1258,12 @@ def dictWithFactory(factory: Callable[[Any], Any]) -> Type: A class to create new dictionaries handling missing keys. """ + issue_deprecation_warning( + "yt.funcs.dictWithFactory will be removed in a future version of yt, please do not rely on it. " + "If you need it, copy paste this function from yt's source code", + since="4.1", + ) + class DictWithFactory(dict): def __init__(self, *args, **kwargs): self.factory = factory From 94b177ca3176d2c4c6e06686872323d9dbb8f8ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:14:20 +0100 Subject: [PATCH 09/54] BLD: bump minimum requirement matplotlib 2.2.3 -> 3.1 --- conftest.py | 9 --------- setup.cfg | 4 ++-- tests/windows_conda_requirements.txt | 2 +- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/conftest.py b/conftest.py index 61912375c0..11cced7614 100644 --- a/conftest.py +++ b/conftest.py @@ -115,15 +115,6 @@ def pytest_configure(config): ): config.addinivalue_line("filterwarnings", value) - if MPL_VERSION < Version("3.0.0"): - config.addinivalue_line( - "filterwarnings", - ( - "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' " - "is deprecated since Python 3.3,and in 3.9 it will stop working:DeprecationWarning" - ), - ) - if MPL_VERSION < Version("3.5.2"): if MPL_VERSION < Version("3.3"): try: diff --git a/setup.cfg b/setup.cfg index eb841a731c..d0cbea441d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ project_urls = packages = find: install_requires = cmyt>=0.2.2 - matplotlib!=3.4.2,>=2.2.3 # keep in sync with tests/windows_conda_requirements.txt + matplotlib!=3.4.2,>=3.1 # keep in sync with tests/windows_conda_requirements.txt more-itertools>=8.4 numpy>=1.14.5 packaging>=20.9 @@ -99,7 +99,7 @@ mapserver = bottle minimal = cmyt==0.2.2 - matplotlib==2.2.3 + matplotlib==3.1 more-itertools==8.4 numpy==1.14.5 tomli==1.2.3 diff --git a/tests/windows_conda_requirements.txt b/tests/windows_conda_requirements.txt index fbec89d138..60dfa90957 100644 --- a/tests/windows_conda_requirements.txt +++ b/tests/windows_conda_requirements.txt @@ -2,5 +2,5 @@ numpy>=1.19.4 cython>=0.29.21,<3.0 cartopy>=0.20.1 h5py~=3.1.0 -matplotlib!=3.4.2,>=2.2.3 # keep in sync with setup.cfg +matplotlib!=3.4.2,>=3.1 # keep in sync with setup.cfg scipy~=1.5.0 From 9580308ede108886e500acbc8f5bc562df52246e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 09:14:46 +0100 Subject: [PATCH 10/54] DOC: update documentation on norm and colorbar API --- doc/source/visualizing/plots.rst | 119 +++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 28 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index 0f4a6d4ba9..1ae98e98ea 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -724,6 +724,9 @@ the axes unit labels. The same result could have been accomplished by explicitly setting the ``width`` to ``(.01, 'Mpc')``. + +.. _set-image-units: + Set image units ~~~~~~~~~~~~~~~ @@ -903,7 +906,7 @@ customization. Colormaps ~~~~~~~~~ -Each of these functions accept two arguments. In all cases the first argument +Each of these functions accepts at least two arguments. In all cases the first argument is a field name. This makes it possible to use different custom colormaps for different fields tracked by the plot object. @@ -920,10 +923,45 @@ Use any of the colormaps listed in the :ref:`colormaps` section. slc.set_cmap(("gas", "density"), "RdBu_r") slc.save() -The :meth:`~yt.visualization.plot_window.AxisAlignedSlicePlot.set_log` function -accepts a field name and a boolean. If the boolean is ``True``, the colormap -for the field will be log scaled. If it is ``False`` the colormap will be -linear. +Colorbar norms +:::::::::::::: + +Slice plots and similar plot classes default to log norms when all values are +strictly positive, and symlog otherwise. yt supports two different interfaces to +move away from the defaults. See **constrained norms** and **arbitrary norm** +hereafter. + +**Constrained norms** + +The norm properties can be constrained via two methods + +- :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` controls the extrema + of the value range ``zmin`` and ``zmax``. +- :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` allows switching to + linear or symlog norms. With symlog, the linear threshold can be set + explicitly. Otherwise, yt will dynamically determine a resonable value. + +Use the :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` +method to set a custom colormap range. + +.. python-script:: + + import yt + + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) + slc.set_zlim(("gas", "density"), zmin=(1e-30, "g/cm**3)), zmax=(1e-25, "g/cm**3")) + slc.save() + +Units can be left out, in which case they implicitly match the current display +units of the colorbar (controlled with the ``set_unit`` method, see +:ref:`_set-image-units`). + +Both ``zmin`` and ``zmax`` are optional, but note that they respectively default +to 0 and 1, which can be widely inapropriate, so it is recommended to specify both. + +:meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` takes a boolean argument +to select log (``True``) or linear (``False``) scalings. .. python-script:: @@ -931,19 +969,30 @@ linear. ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_log(("gas", "density"), False) + slc.set_log(("gas", "density"), False) # switch to linear scaling slc.save() -Specifically, a field containing both positive and negative values can be plotted -with symlog scale, by setting the boolean to be ``True`` and either providing an extra -parameter ``linthresh`` or setting ``symlog_auto = True``. In the region around zero -(when the log scale approaches to infinity), the linear scale will be applied to the -region ``(-linthresh, linthresh)`` and stretched relative to the logarithmic range. -In some cases, if yt detects zeros present in the dataset and the user has selected -``log`` scaling, yt automatically switches to ``symlog`` scaling and automatically -chooses a ``linthresh`` value to avoid errors. This is the same behavior you can -achieve by setting the keyword ``symlog_auto`` to ``True``. In these cases, yt will -choose the smallest non-zero value in a dataset to be the ``linthresh`` value. +One can switch to `symlog +`_ +by providing a "linear threshold" (``linthresh``) + +.. python-script:: + + import yt + + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(10, "kpc")) + slc.set_log(("gas", "density"), linthresh=(1e-26, "g/cm**3")) + slc.save() + +Similar to the ``zmin`` and ``zmax`` arguments of the ``set_zlim`` method, units +can be left out in ``linthresh``. + + ``linthresh="auto"`` is also +valid, in this case yt will switch to symlog norm and guess an appropriate value +automatically. Specifically the minimum absolute value in the image is used +unless it's zero, in which case yt uses 1/1000 of the maximum value. + As an example, .. python-script:: @@ -952,7 +1001,7 @@ As an example, ds = yt.load_sample("FIRE_M12i_ref11") p = yt.ProjectionPlot(ds, "x", ("gas", "density")) - p.set_log(("gas", "density"), True, symlog_auto=True) + p.set_log(("gas", "density"), linthresh="auto") p.save() Symlog is very versatile, and will work with positive or negative dataset ranges. @@ -965,7 +1014,31 @@ Here is an example using symlog scaling to plot a positive field with a linear r ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(30, "kpc")) - slc.set_log(("gas", "velocity_x"), True, linthresh=1.0e1) + slc.set_log(("gas", "velocity_x"), linthresh=1.0e1) + slc.save() + +**Arbitrary norms** + +Alternatively, arbitrary matplotlib norms can be passed via the +:meth:`~yt.visualization.plot_container.PlotContainer.set_norm` method. In that +case, any numeric value is treated as having implicit units, matching the +current display units. This alternative interface is more flexible, but +considered experimental as of yt 4.1. Don't forget that with great power comes +great responsibility. + + +.. python-script:: + + import yt + from matplotlib.colors import TwoSlopeNorm + + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(30, "kpc")) + slc.set_norm(("gas", "velocity_x"), TwoSlopeNorm(vcenter=0)) + + # using a diverging colormap to emphasize that vcenter corresponds to the + # middle value in the color range + slc.set_cmap(("gas", "velocity_x"), "RdBu") slc.save() The :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_background_color` @@ -1000,17 +1073,7 @@ you will need to make use of the ``draw_frame`` keyword argument for the ``hide_ slc.hide_colorbar() slc.save("just_image") -Lastly, the :meth:`~yt.visualization.plot_window.AxisAlignedSlicePlot.set_zlim` -function makes it possible to set a custom colormap range. - -.. python-script:: - import yt - - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_zlim(("gas", "density"), 1e-30, 1e-25) - slc.save() Annotations ~~~~~~~~~~~ From 3fb4af0c075b1e91fa141b575b33afba1d315320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 20 Mar 2022 18:22:19 +0100 Subject: [PATCH 11/54] ENH: improve coherence between ImagePlotContainer.hide_axes and ImagePlotContainer.set_background_color --- doc/source/visualizing/plots.rst | 17 ----------------- yt/visualization/_handlers.py | 10 +++++++--- yt/visualization/base_plot_types.py | 15 ++++++++++++--- yt/visualization/plot_container.py | 4 ++-- yt/visualization/plot_window.py | 2 +- 5 files changed, 22 insertions(+), 26 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index 1ae98e98ea..eb7aa0789c 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -1057,23 +1057,6 @@ value of the color map. slc.set_background_color(("gas", "density"), color="black") slc.save("black_background") -If you would like to change the background for a plot and also hide the axes, -you will need to make use of the ``draw_frame`` keyword argument for the ``hide_axes`` function. If you do not use this keyword argument, the call to -``set_background_color`` will have no effect. Here is an example illustrating how to use the ``draw_frame`` keyword argument for ``hide_axes``: - -.. python-script:: - - import yt - - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - field = ("deposit", "all_density") - slc = yt.ProjectionPlot(ds, "z", field, width=(1.5, "Mpc")) - slc.set_background_color(field) - slc.hide_axes(draw_frame=True) - slc.hide_colorbar() - slc.save("just_image") - - Annotations ~~~~~~~~~~~ diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index b3392ea3fe..61a6f35ce8 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -294,7 +294,7 @@ def __init__( draw_cbar: bool = True, draw_minorticks: bool = True, cmap: Optional[Union[Colormap, str]] = None, - background_color: Optional[str] = "white", + background_color: Optional[str] = None, ): self._draw_cbar = draw_cbar self._draw_minorticks = draw_minorticks @@ -346,8 +346,8 @@ def cmap(self, newval) -> None: ) @property - def background_color(self): - return self._background_color + def background_color(self) -> str: + return self._background_color or "white" @background_color.setter def background_color(self, newval): @@ -355,3 +355,7 @@ def background_color(self, newval): self._background_color = self.cmap(0) else: self._background_color = newval + + @property + def has_background_color(self) -> bool: + return self._background_color is not None diff --git a/yt/visualization/base_plot_types.py b/yt/visualization/base_plot_types.py index 41f7530c35..99c294f3ed 100644 --- a/yt/visualization/base_plot_types.py +++ b/yt/visualization/base_plot_types.py @@ -1,3 +1,4 @@ +import warnings from abc import ABC from io import BytesIO from typing import Optional, Tuple, Union @@ -476,10 +477,18 @@ def _toggle_axes(self, choice, draw_frame=None): If True, set the axes to be drawn. If False, set the axes to not be drawn. """ - if draw_frame is None: - draw_frame = choice self._draw_axes = choice self._draw_frame = draw_frame + if draw_frame is None: + draw_frame = choice + if self.colorbar_handler.has_background_color and not draw_frame: + # workaround matplotlib's behaviour + # last checked with Matplotlib 3.5 + warnings.warn( + f"Previously set background color {self.colorbar_handler.background_color} " + "has no effect. Pass `draw_axis=True` if you wish to preserve background color.", + stacklevel=4, + ) self.axes.set_frame_on(draw_frame) self.axes.get_xaxis().set_visible(choice) self.axes.get_yaxis().set_visible(choice) @@ -505,7 +514,7 @@ def _get_labels(self): labels += [cbax.yaxis.label, cbax.yaxis.get_offset_text()] return labels - def hide_axes(self, draw_frame=None): + def hide_axes(self, *, draw_frame=None): """ Hide the axes for a plot including ticks and labels """ diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 3b3fbbe81a..7107c03391 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -795,7 +795,7 @@ def show_colorbar(self, field=None): self.plots[f].show_colorbar() return self - def hide_axes(self, field=None, draw_frame=False): + def hide_axes(self, field=None, draw_frame=None): """ Hides the axes for a plot and updates the size of the plot accordingly. Defaults to operating on all fields for a @@ -841,7 +841,7 @@ def hide_axes(self, field=None, draw_frame=False): if field is None: field = self.fields for f in iter_fields(field): - self.plots[f].hide_axes(draw_frame) + self.plots[f].hide_axes(draw_frame=draw_frame) return self def show_axes(self, field=None): diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 709c2e1ade..e2c047b8ef 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -1090,7 +1090,7 @@ def _setup_plots(self): axes = None cax = None draw_axes = True - draw_frame = draw_axes + draw_frame = None if f in self.plots: draw_axes = self.plots[f]._draw_axes draw_frame = self.plots[f]._draw_frame From afb3d17969e2df358371c0d96d2a9efb417711bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 16 Apr 2022 17:11:23 +0200 Subject: [PATCH 12/54] TST: add image tests for norm api --- tests/tests.yaml | 12 ++++++ .../tests/test_norm_api_custom_norm.py | 42 +++++++++++++++++++ .../tests/test_norm_api_inf_zlim.py | 39 +++++++++++++++++ .../tests/test_norm_api_lineplot.py | 32 ++++++++++++++ .../tests/test_norm_api_particleplot.py | 29 +++++++++++++ ...orm_api_phaseplot_set_colorbar_explicit.py | 33 +++++++++++++++ ...orm_api_phaseplot_set_colorbar_implicit.py | 33 +++++++++++++++ .../tests/test_norm_api_profileplot.py | 29 +++++++++++++ .../test_norm_api_set_background_color.py | 24 +++++++++++ .../tests/test_norm_api_set_unit_and_zlim.py | 22 ++++++++++ .../tests/test_norm_api_set_zlim_and_unit.py | 22 ++++++++++ 11 files changed, 317 insertions(+) create mode 100644 yt/visualization/tests/test_norm_api_custom_norm.py create mode 100644 yt/visualization/tests/test_norm_api_inf_zlim.py create mode 100644 yt/visualization/tests/test_norm_api_lineplot.py create mode 100644 yt/visualization/tests/test_norm_api_particleplot.py create mode 100644 yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py create mode 100644 yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py create mode 100644 yt/visualization/tests/test_norm_api_profileplot.py create mode 100644 yt/visualization/tests/test_norm_api_set_background_color.py create mode 100644 yt/visualization/tests/test_norm_api_set_unit_and_zlim.py create mode 100644 yt/visualization/tests/test_norm_api_set_zlim_and_unit.py diff --git a/tests/tests.yaml b/tests/tests.yaml index 6455995fd0..deb3b67fe5 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -183,6 +183,18 @@ answer_tests: local_nc4_cm1_001: # PR 2176 - yt/frontends/nc4_cm1/tests/test_outputs.py:test_cm1_mesh_fields + local_norm_api_007: # PR 3849 + - yt/visualization/tests/test_norm_api_lineplot.py:test_lineplot_set_axis_properties + - yt/visualization/tests/test_norm_api_profileplot.py:test_profileplot_set_axis_properties + - yt/visualization/tests/test_norm_api_custom_norm.py:test_sliceplot_custom_norm + - yt/visualization/tests/test_norm_api_set_zlim_and_unit.py:test_sliceplot_set_zlim_and_unit + - yt/visualization/tests/test_norm_api_set_unit_and_zlim.py:test_sliceplot_set_unit_and_zlim + - yt/visualization/tests/test_norm_api_set_background_color.py:test_sliceplot_set_background_color + - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py:test_phaseplot_set_colorbar_properties_implicit + - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py:test_phaseplot_set_colorbar_properties_explicit + - yt/visualization/tests/test_norm_api_particleplot.py:test_particleprojectionplot_set_colorbar_properties + - yt/visualization/tests/test_norm_api_inf_zlim.py:test_inf_and_finite_values_zlim + other_tests: unittests: - "--exclude=test_mesh_slices" # disable randomly failing test diff --git a/yt/visualization/tests/test_norm_api_custom_norm.py b/yt/visualization/tests/test_norm_api_custom_norm.py new file mode 100644 index 0000000000..67ccca6f1a --- /dev/null +++ b/yt/visualization/tests/test_norm_api_custom_norm.py @@ -0,0 +1,42 @@ +import matplotlib +from matplotlib.colors import LogNorm, Normalize, SymLogNorm +from nose.plugins.attrib import attr +from packaging.version import Version + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds, skip_case +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + +MPL_VERSION = Version(matplotlib.__version__) + + +@attr(ANSWER_TEST_TAG) +def test_sliceplot_custom_norm(): + if MPL_VERSION < Version("3.4"): + skip_case( + reason="in MPL<3.4, SymLogNorm emits a deprecation warning " + "that cannot be easily filtered" + ) + # don't import this at top level because it's only available since MPL 3.2 + from matplotlib.colors import TwoSlopeNorm + + norms_to_test = [ + (Normalize(), "linear"), + (LogNorm(), "log"), + (TwoSlopeNorm(vcenter=0, vmin=-0.5, vmax=1), "twoslope"), + (SymLogNorm(linthresh=0.01, vmin=-1, vmax=1), "symlog"), + ] + + ds = fake_random_ds(16) + + def create_image(filename_prefix): + field = ("gas", "density") + for norm, name in norms_to_test: + p = SlicePlot(ds, "z", field) + p.set_norm(field, norm=norm) + p.save(f"{filename_prefix}_{name}") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_custom_norm" + test.answer_name = "sliceplot_custom_norm" + yield test diff --git a/yt/visualization/tests/test_norm_api_inf_zlim.py b/yt/visualization/tests/test_norm_api_inf_zlim.py new file mode 100644 index 0000000000..17c89e3d16 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_inf_zlim.py @@ -0,0 +1,39 @@ +import numpy as np +from nose.plugins.attrib import attr + +from yt.loaders import load_uniform_grid +from yt.testing import ANSWER_TEST_TAG +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + + +@attr(ANSWER_TEST_TAG) +def test_inf_and_finite_values_zlim(): + # see https://github.com/yt-project/yt/issues/3901 + shape = (32, 16, 1) + a = np.ones(16) + b = np.ones((32, 16)) + c = np.reshape(a * b, shape) + + # injecting an inf + c[0, 0, 0] = np.inf + + data = {("gas", "density"): c} + + ds = load_uniform_grid( + data, + shape, + bbox=np.array([[0, 1], [0, 1], [0, 1]]), + ) + + def create_image(filename_prefix): + p = SlicePlot(ds, "z", ("gas", "density")) + + # setting zlim manually + p.set_zlim(("gas", "density"), -10, 10) + p.save(filename_prefix) + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_inf_and_finite_values_zlim" + test.answer_name = "inf_and_finite_values_zlim" + yield test diff --git a/yt/visualization/tests/test_norm_api_lineplot.py b/yt/visualization/tests/test_norm_api_lineplot.py new file mode 100644 index 0000000000..3da3b7aa96 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_lineplot.py @@ -0,0 +1,32 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import LinePlot + + +@attr(ANSWER_TEST_TAG) +def test_lineplot_set_axis_properties(): + ds = fake_random_ds(16) + + def create_image(filename_prefix): + p = LinePlot( + ds, + ("gas", "density"), + start_point=[0, 0, 0], + end_point=[1, 1, 1], + npoints=32, + ) + p.set_x_unit("cm") + p.save(f"{filename_prefix}_xunit") + + p.set_unit(("gas", "density"), "kg/cm**3") + p.save(f"{filename_prefix}_xunit_zunit") + + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_xunit_zunit_lin") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_lineplot_set_axis_properties" + test.answer_name = "lineplot_set_axis_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_particleplot.py b/yt/visualization/tests/test_norm_api_particleplot.py new file mode 100644 index 0000000000..0cd2167a29 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_particleplot.py @@ -0,0 +1,29 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_particle_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import ParticleProjectionPlot + + +@attr(ANSWER_TEST_TAG) +def test_particleprojectionplot_set_colorbar_properties(): + ds = fake_particle_ds(npart=100) + + def create_image(filename_prefix): + field = ("all", "particle_mass") + p = ParticleProjectionPlot(ds, 2, field) + p.set_buff_size(10) + + p.set_unit(field, "Msun") + p.save(f"{filename_prefix}_set_unit") + + p.set_zlim(field, zmax=1e-35) + p.save(f"{filename_prefix}_set_unit_zlim") + + p.set_log(field, False) + p.save(f"{filename_prefix}_set_unit_zlim_log") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_particleprojectionplot_set_colorbar_properties" + test.answer_name = "particleprojectionplot_set_colorbar_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py new file mode 100644 index 0000000000..8125040b6a --- /dev/null +++ b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py @@ -0,0 +1,33 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, add_noise_fields, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import PhasePlot + + +@attr(ANSWER_TEST_TAG) +def test_phaseplot_set_colorbar_properties_explicit(): + ds = fake_random_ds(16) + add_noise_fields(ds) + + def create_image(filename_prefix): + my_sphere = ds.sphere("c", 1) + p = PhasePlot( + my_sphere, + ("gas", "noise1"), + ("gas", "noise3"), + [("gas", "density")], + weight_field=None, + ) + # using explicit units, we expect the colorbar units to stay unchanged + p.set_zlim(("gas", "density"), zmin=(1e36, "kg/AU**3")) + p.save(f"{filename_prefix}_set_zlim_explicit") + + # ... until we set them explicitly + p.set_unit(("gas", "density"), "kg/AU**3") + p.save(f"{filename_prefix}_set_zlim_set_unit_explicit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_phaseplot_set_colorbar_properties_explicit" + test.answer_name = "phaseplot_set_colorbar_properties_explicit" + yield test diff --git a/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py new file mode 100644 index 0000000000..6bb1f29513 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py @@ -0,0 +1,33 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, add_noise_fields, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import PhasePlot + + +@attr(ANSWER_TEST_TAG) +def test_phaseplot_set_colorbar_properties_implicit(): + ds = fake_random_ds(16) + add_noise_fields(ds) + + def create_image(filename_prefix): + my_sphere = ds.sphere("c", 1) + p = PhasePlot( + my_sphere, + ("gas", "noise1"), + ("gas", "noise3"), + [("gas", "density")], + weight_field=None, + ) + # using implicit units + p.set_zlim(("gas", "density"), zmax=10) + p.save(f"{filename_prefix}_set_zlim_implicit") + + # changing units should affect the colorbar and not the image + p.set_unit(("gas", "density"), "kg/AU**3") + p.save(f"{filename_prefix}_set_zlim_set_unit_implicit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_phaseplot_set_colorbar_properties_implicit" + test.answer_name = "phaseplot_set_colorbar_properties_implicit" + yield test diff --git a/yt/visualization/tests/test_norm_api_profileplot.py b/yt/visualization/tests/test_norm_api_profileplot.py new file mode 100644 index 0000000000..ee0e66ef41 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_profileplot.py @@ -0,0 +1,29 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import ProfilePlot + + +@attr(ANSWER_TEST_TAG) +def test_profileplot_set_axis_properties(): + ds = fake_random_ds(16) + + def create_image(filename_prefix): + disk = ds.disk(ds.domain_center, [0.0, 0.0, 1.0], (10, "m"), (1, "m")) + p = ProfilePlot(disk, ("gas", "density"), [("gas", "velocity_x")]) + p.save(f"{filename_prefix}_defaults") + + p.set_unit(("gas", "density"), "kg/cm**3") + p.save(f"{filename_prefix}_xunit") + + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_xunit_xlin") + + p.set_unit(("gas", "velocity_x"), "mile/hour") + p.save(f"{filename_prefix}_xunit_xlin_yunit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_profileplot_set_axis_properties" + test.answer_name = "profileplot_set_axis_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_set_background_color.py b/yt/visualization/tests/test_norm_api_set_background_color.py new file mode 100644 index 0000000000..a9d92b7247 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_set_background_color.py @@ -0,0 +1,24 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + + +@attr(ANSWER_TEST_TAG) +def test_sliceplot_set_background_color(): + # see https://github.com/yt-project/yt/issues/3854 + ds = fake_random_ds(16) + + def create_image(filename_prefix): + field = ("gas", "density") + p = SlicePlot(ds, "z", field, width=1.5) + p.set_background_color(field, color="C0") + p.save(f"{filename_prefix}_log") + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_lin") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_set_background_color" + test.answer_name = "sliceplot_set_background_color" + yield test diff --git a/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py new file mode 100644 index 0000000000..1e78fe75be --- /dev/null +++ b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py @@ -0,0 +1,22 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + + +@attr(ANSWER_TEST_TAG) +def test_sliceplot_set_unit_and_zlim(): + ds = fake_random_ds(16) + + def create_image(filename_prefix): + field = ("gas", "density") + p = SlicePlot(ds, "z", field) + p.set_unit(field, "kg/m**3") + p.set_zlim(field, zmin=0) + p.save(filename_prefix) + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_set_unit_and_zlim" + test.answer_name = "sliceplot_set_unit_and_zlim" + yield test diff --git a/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py b/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py new file mode 100644 index 0000000000..41935974b1 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py @@ -0,0 +1,22 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + +ds = fake_random_ds(16) + + +@attr(ANSWER_TEST_TAG) +def test_sliceplot_set_zlim_and_unit(): + def create_image(filename_prefix): + field = ("gas", "density") + p = SlicePlot(ds, "z", field) + p.set_zlim(field, zmin=0) + p.set_unit(field, "kg/m**3") + p.save(filename_prefix) + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_set_zlim_and_unit" + test.answer_name = "sliceplot_set_zlim_and_unit" + yield test From 5ac1a3971f07083ae07318cac234f6936787e87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 3 May 2022 22:03:29 +0200 Subject: [PATCH 13/54] BUG: fix blocking behaviours in tests/report_failed_answers.py --- tests/report_failed_answers.py | 7 +++++++ yt/utilities/answer_testing/framework.py | 1 + 2 files changed, 8 insertions(+) diff --git a/tests/report_failed_answers.py b/tests/report_failed_answers.py index 8cbfa6c7f5..78a7c55d68 100644 --- a/tests/report_failed_answers.py +++ b/tests/report_failed_answers.py @@ -13,6 +13,7 @@ import os import re import shutil +import sys import tempfile import xml.etree.ElementTree as ET @@ -425,6 +426,9 @@ def handle_error(error, testcase, missing_errors, missing_answers, failed_answer + "\n" ) response = upload_answers(failed_answers) + if response is None: + log.error("Failed to upload answers for failed tests !") + sys.exit(1) if response.ok: msg += ( FLAG_EMOJI @@ -438,6 +442,9 @@ def handle_error(error, testcase, missing_errors, missing_answers, failed_answer if args.upload_missing_answers and missing_answers: response = upload_answers(missing_answers) + if response is None: + log.error("Failed to upload missing answers !") + sys.exit(1) if response.ok: msg = ( FLAG_EMOJI diff --git a/yt/utilities/answer_testing/framework.py b/yt/utilities/answer_testing/framework.py index c66c5a74a8..721846cc44 100644 --- a/yt/utilities/answer_testing/framework.py +++ b/yt/utilities/answer_testing/framework.py @@ -278,6 +278,7 @@ def get(self, ds_name, default=None): return default # Read data using shelve answer_name = f"{ds_name}" + os.makedirs(os.path.dirname(self.reference_name), exist_ok=True) ds = shelve.open(self.reference_name, protocol=-1) try: result = ds[answer_name] From 02259d80d09ef4369d7b39dc640d2f022428b3c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 4 May 2022 11:25:13 +0200 Subject: [PATCH 14/54] bump answer store --- answer-store | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/answer-store b/answer-store index 997baaad97..dafcb51e89 160000 --- a/answer-store +++ b/answer-store @@ -1 +1 @@ -Subproject commit 997baaad97a69b04226e4e1a31171860eb38b491 +Subproject commit dafcb51e89e88acb03eaa8312ef96ffc2f4f4e04 From 51f94d27105855064a9d0789ce90c2d629298e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 4 May 2022 12:00:55 +0200 Subject: [PATCH 15/54] bump answer store (new answers) --- answer-store | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/answer-store b/answer-store index dafcb51e89..d9c1557f7d 160000 --- a/answer-store +++ b/answer-store @@ -1 +1 @@ -Subproject commit dafcb51e89e88acb03eaa8312ef96ffc2f4f4e04 +Subproject commit d9c1557f7dba849a4446c11e2239b1822d1b7fde From c54167034271496cadc4f384dadd9bf45b261f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 4 May 2022 14:43:20 +0200 Subject: [PATCH 16/54] bump answers on Jenkins --- tests/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests.yaml b/tests/tests.yaml index deb3b67fe5..f2f81ef4a6 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -95,7 +95,7 @@ answer_tests: - yt/frontends/owls/tests/test_outputs.py:test_snapshot_033 - yt/frontends/owls/tests/test_outputs.py:test_OWLS_particlefilter - local_pw_044: # PR 3640 + local_pw_045: # PR 3849 - yt/visualization/tests/test_plotwindow.py:test_attributes - yt/visualization/tests/test_particle_plot.py:test_particle_projection_answers - yt/visualization/tests/test_particle_plot.py:test_particle_projection_filter From 5f1e8356d08ab1db3dc0638550febb6090530b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 17:45:16 +0200 Subject: [PATCH 17/54] Apply suggestions from code review Co-authored-by: Chris Havlin --- doc/source/visualizing/plots.rst | 6 +++--- yt/visualization/_commons.py | 2 +- yt/visualization/plot_container.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index eb7aa0789c..c559c2a300 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -950,7 +950,7 @@ method to set a custom colormap range. ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_zlim(("gas", "density"), zmin=(1e-30, "g/cm**3)), zmax=(1e-25, "g/cm**3")) + slc.set_zlim(("gas", "density"), zmin=(1e-30, "g/cm**3"), zmax=(1e-25, "g/cm**3")) slc.save() Units can be left out, in which case they implicitly match the current display @@ -958,7 +958,7 @@ units of the colorbar (controlled with the ``set_unit`` method, see :ref:`_set-image-units`). Both ``zmin`` and ``zmax`` are optional, but note that they respectively default -to 0 and 1, which can be widely inapropriate, so it is recommended to specify both. +to 0 and 1, which can be widely inappropriate, so it is recommended to specify both. :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` takes a boolean argument to select log (``True``) or linear (``False``) scalings. @@ -981,7 +981,7 @@ by providing a "linear threshold" (``linthresh``) import yt ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(10, "kpc")) + slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) slc.set_log(("gas", "density"), linthresh=(1e-26, "g/cm**3")) slc.save() diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index ae4f503d53..ff63dbf872 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -212,7 +212,7 @@ def _swap_arg_pair_order(*args): return tuple(new_args) -def get_log_minorticks(vmin, vmax): +def get_log_minorticks(vmin: float, vmax: float) -> np.ndarray: """calculate positions of linear minorticks on a log colorbar Parameters diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 7107c03391..eab1788b37 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -164,7 +164,7 @@ def set_log( if field == 'all', applies to all plots. log : boolean, optional Log on/off: on means log scaling; off means linear scaling. - linthresh : float, or 'auto', optional + linthresh : float, (float, str), or 'auto', optional when using symlog scaling, linthresh is the value at which scaling transitions from linear to logarithmic. linthresh must be positive. Note: setting linthresh will automatically enable symlog scale From 43113bc67ac9fdd195646e5052efcb43b7cdc8fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 17:48:02 +0200 Subject: [PATCH 18/54] DOC: add missing docstring --- yt/visualization/_commons.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index ff63dbf872..c7de3be68a 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -273,6 +273,18 @@ def get_symlog_minorticks(linthresh: float, vmin: float, vmax: float) -> np.ndar def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray: + """calculate positions of major ticks on a log colorbar + + Parameters + ---------- + linthresh : float + the threshold for the linear region + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ if vmin >= 0.0: yticks = [vmin] + list( 10 From cc57aa2ae67ec11bc4dcb4c25b0379ffbfbb6c06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 17:57:17 +0200 Subject: [PATCH 19/54] TYP: complete a type hint, add dependency to 'typing_extensions' for Python < 3.8 --- setup.cfg | 1 + yt/_typing.py | 5 +++++ yt/visualization/plot_container.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index d0cbea441d..b44bf6eb9d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,6 +48,7 @@ install_requires = tomli-w>=0.4.0 tqdm>=3.4.0 unyt>=2.8.0 + typing-extensions>=4.2.0;python_version < '3.8' python_requires = >=3.7,<3.12 include_package_data = True scripts = scripts/iyt diff --git a/yt/_typing.py b/yt/_typing.py index 4b76c9899f..4b2c374cfc 100644 --- a/yt/_typing.py +++ b/yt/_typing.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple, Union from numpy import ndarray +from unyt import unyt_quantity FieldDescT = Tuple[str, Tuple[str, List[str], Optional[str]]] KnownFieldsT = Tuple[FieldDescT, ...] @@ -10,3 +11,7 @@ Tuple[ndarray, ndarray, ndarray], # xyz Union[float, ndarray], # hsml ] + +# an intentionally restrictive list of types that can +# be passes to ds.quan (which is a proxy for unyt.unyt_quantity.__init__) +Quantity = Union[unyt_quantity, Tuple[float, str]] diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index eab1788b37..f7a7b5738c 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -15,6 +15,7 @@ from unyt.dimensions import length from yt._maintenance.deprecation import issue_deprecation_warning +from yt._typing import Quantity from yt.config import ytcfg from yt.data_objects.time_series import DatasetSeries from yt.funcs import ensure_dir, is_sequence, iter_fields @@ -34,6 +35,11 @@ validate_plot, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + latex_prefixes = { "u": r"\mu", } @@ -145,7 +151,7 @@ def set_log( field, log: Optional[bool] = None, *, - linthresh: Optional[Union[float, str]] = None, + linthresh: Optional[Union[float, Quantity, Literal["auto"]]] = None, symlog_auto: Optional[bool] = None, # deprecated ): """set a field to log, linear, or symlog. @@ -164,7 +170,7 @@ def set_log( if field == 'all', applies to all plots. log : boolean, optional Log on/off: on means log scaling; off means linear scaling. - linthresh : float, (float, str), or 'auto', optional + linthresh : float, (float, str), unyt_quantity, or 'auto', optional when using symlog scaling, linthresh is the value at which scaling transitions from linear to logarithmic. linthresh must be positive. Note: setting linthresh will automatically enable symlog scale From 8f7433f38380c5d9aa9cf7b79acb2a7efb957eb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 18:13:30 +0200 Subject: [PATCH 20/54] DOC: rework symlog demo docs --- doc/source/visualizing/plots.rst | 35 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index c559c2a300..ca44bca37f 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -974,26 +974,23 @@ to select log (``True``) or linear (``False``) scalings. One can switch to `symlog `_ -by providing a "linear threshold" (``linthresh``) +by providing a "linear threshold" (``linthresh``) value. +With ``linthresh="auto"`` yt will switch to symlog norm and guess an appropriate value +automatically. Specifically the minimum absolute value in the image is used +unless it's zero, in which case yt uses 1/1000 of the maximum value. .. python-script:: import yt - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + ds = yt.load_sample("IsolatedGalaxy") slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_log(("gas", "density"), linthresh=(1e-26, "g/cm**3")) + slc.set_log(("gas", "density"), linthresh="auto") slc.save() -Similar to the ``zmin`` and ``zmax`` arguments of the ``set_zlim`` method, units -can be left out in ``linthresh``. - ``linthresh="auto"`` is also -valid, in this case yt will switch to symlog norm and guess an appropriate value -automatically. Specifically the minimum absolute value in the image is used -unless it's zero, in which case yt uses 1/1000 of the maximum value. - -As an example, +In some cases, you might find that the automatically selected linear threshold is not +really suited to your dataset, for instance .. python-script:: @@ -1004,18 +1001,20 @@ As an example, p.set_log(("gas", "density"), linthresh="auto") p.save() -Symlog is very versatile, and will work with positive or negative dataset ranges. -Here is an example using symlog scaling to plot a positive field with a linear range of -``(0, linthresh)``. +An explicit value can be passed instead .. python-script:: import yt - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(30, "kpc")) - slc.set_log(("gas", "velocity_x"), linthresh=1.0e1) - slc.save() + ds = yt.load_sample("FIRE_M12i_ref11") + p = yt.ProjectionPlot(ds, "x", ("gas", "density")) + p.set_log(("gas", "density"), linthresh=(1e-22, "g/cm**2")) + p.save() + +Similar to the ``zmin`` and ``zmax`` arguments of the ``set_zlim`` method, units +can be left out in ``linthresh``. + **Arbitrary norms** From 0a443cd3f162a6f9554a57b5717481b3b9eec02d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 18:53:55 +0200 Subject: [PATCH 21/54] TYP: fix typing --- yt/visualization/_commons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index c7de3be68a..9e8cd3e10f 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -227,7 +227,7 @@ def get_log_minorticks(vmin: float, vmax: float) -> np.ndarray: expB = np.floor(np.log10(vmax)) cofA = np.ceil(vmin / 10**expA).astype("int64") cofB = np.floor(vmax / 10**expB).astype("int64") - lmticks = [] + lmticks = np.empty(0) while cofA * 10**expA <= cofB * 10**expB: if expA < expB: lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA)) From 1a10ec7e3d42d4edb06673ee618ad72584f54ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 14 May 2022 19:29:47 +0200 Subject: [PATCH 22/54] TYP: typing improvements for yt.visualization._handlers.py --- yt/_typing.py | 11 +++++++---- yt/visualization/_handlers.py | 35 ++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/yt/_typing.py b/yt/_typing.py index 4b2c374cfc..fd248f343d 100644 --- a/yt/_typing.py +++ b/yt/_typing.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Union +import unyt as un from numpy import ndarray -from unyt import unyt_quantity FieldDescT = Tuple[str, Tuple[str, List[str], Optional[str]]] KnownFieldsT = Tuple[FieldDescT, ...] @@ -12,6 +12,9 @@ Union[float, ndarray], # hsml ] -# an intentionally restrictive list of types that can -# be passes to ds.quan (which is a proxy for unyt.unyt_quantity.__init__) -Quantity = Union[unyt_quantity, Tuple[float, str]] + +# types that can be converted to un.Unit +Unit = Union[un.Unit, str] + +# types that can be converted to un.unyt_quantity +Quantity = Union[un.unyt_quantity, Tuple[float, Unit]] diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 61a6f35ce8..1e1f946554 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Type, Union import numpy as np +import unyt as un from matplotlib.cm import get_cmap from matplotlib.colors import Colormap, LogNorm, Normalize, SymLogNorm from packaging.version import Version -from unyt import Unit, unyt_quantity +from yt._typing import Quantity, Unit from yt.config import ytcfg from yt.funcs import get_brewer_cmap, is_sequence, mylog from yt.visualization._commons import MPL_VERSION @@ -45,9 +46,9 @@ def __init__( self, data_source, *, - display_units: Unit, - vmin: Optional[unyt_quantity] = None, - vmax: Optional[unyt_quantity] = None, + display_units: un.Unit, + vmin: Optional[un.unyt_quantity] = None, + vmax: Optional[un.unyt_quantity] = None, norm_type: Optional[Type[Normalize]] = None, norm: Optional[Normalize] = None, linthresh: Optional[float] = None, @@ -99,17 +100,17 @@ def _reset_norm(self): mylog.warning("Dropping norm (%s)", self.norm) self._norm = None - def to_float(self, val: unyt_quantity) -> float: + def to_float(self, val: un.unyt_quantity) -> float: return float(val.to(self.display_units).d) - def to_quan(self, val) -> unyt_quantity: - if isinstance(val, unyt_quantity): + def to_quan(self, val) -> un.unyt_quantity: + if isinstance(val, un.unyt_quantity): return self.ds.quan(val) elif ( is_sequence(val) and len(val) == 2 and isinstance(val[0], Real) - and isinstance(val[1], (str, Unit)) + and isinstance(val[1], (str, un.Unit)) ): return self.ds.quan(*val) elif isinstance(val, Real): @@ -118,15 +119,15 @@ def to_quan(self, val) -> unyt_quantity: raise TypeError(f"Could not convert {val!r} to unyt_quantity") @property - def display_units(self) -> Unit: + def display_units(self) -> un.Unit: return self._display_units @display_units.setter - def display_units(self, newval: Union[str, Unit]) -> None: - self._display_units = Unit(newval) + def display_units(self, newval: Unit) -> None: + self._display_units = un.Unit(newval, registry=self.ds.unit_registry) def _set_quan_attr( - self, attr: str, newval: Optional[Union[unyt_quantity, float]] + self, attr: str, newval: Optional[Union[Quantity, float]] ) -> None: if newval is None: setattr(self, attr, None) @@ -144,20 +145,20 @@ def _set_quan_attr( setattr(self, attr, quan) @property - def vmin(self) -> Optional[unyt_quantity]: + def vmin(self) -> Optional[un.unyt_quantity]: return self._vmin @vmin.setter - def vmin(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + def vmin(self, newval: Optional[Union[Quantity, float]]) -> None: self._reset_norm() self._set_quan_attr("_vmin", newval) @property - def vmax(self) -> Optional[unyt_quantity]: + def vmax(self) -> Optional[un.unyt_quantity]: return self._vmax @vmax.setter - def vmax(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + def vmax(self, newval: Optional[Union[Quantity, float]]) -> None: self._reset_norm() self._set_quan_attr("_vmax", newval) @@ -199,7 +200,7 @@ def linthresh(self) -> Optional[float]: return self._linthresh @linthresh.setter - def linthresh(self, newval: Optional[Union[unyt_quantity, float]]) -> None: + def linthresh(self, newval: Optional[Union[Quantity, float]]) -> None: self._reset_norm() self._set_quan_attr("_linthresh", newval) if self._linthresh is not None and self._linthresh <= 0: From a3a5c16f0a904e3d07f9def6b0b62d81b814be9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 09:43:50 +0200 Subject: [PATCH 23/54] DOC: fix typos Co-authored-by: Chris Havlin --- doc/source/visualizing/plots.rst | 6 +++--- yt/visualization/_handlers.py | 4 ++-- yt/visualization/plot_container.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index ca44bca37f..dedc867658 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -939,7 +939,7 @@ The norm properties can be constrained via two methods of the value range ``zmin`` and ``zmax``. - :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` allows switching to linear or symlog norms. With symlog, the linear threshold can be set - explicitly. Otherwise, yt will dynamically determine a resonable value. + explicitly. Otherwise, yt will dynamically determine a reasonable value. Use the :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` method to set a custom colormap range. @@ -997,7 +997,7 @@ really suited to your dataset, for instance import yt ds = yt.load_sample("FIRE_M12i_ref11") - p = yt.ProjectionPlot(ds, "x", ("gas", "density")) + p = yt.ProjectionPlot(ds, "x", ("gas", "density"), width=(30, "Mpc")) p.set_log(("gas", "density"), linthresh="auto") p.save() @@ -1008,7 +1008,7 @@ An explicit value can be passed instead import yt ds = yt.load_sample("FIRE_M12i_ref11") - p = yt.ProjectionPlot(ds, "x", ("gas", "density")) + p = yt.ProjectionPlot(ds, "x", ("gas", "density"), width=(30, "Mpc")) p.set_log(("gas", "density"), linthresh=(1e-22, "g/cm**2")) p.save() diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 1e1f946554..9ba9c57edf 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -205,7 +205,7 @@ def linthresh(self, newval: Optional[Union[Quantity, float]]) -> None: self._set_quan_attr("_linthresh", newval) if self._linthresh is not None and self._linthresh <= 0: raise ValueError( - f"linthresh can only be set to stricly positive values, got {newval}" + f"linthresh can only be set to strictly positive values, got {newval}" ) if newval is not None: self.norm_type = SymLogNorm @@ -234,7 +234,7 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: min_abs_val, max_abs_val = np.sort(np.abs((kw["vmin"], kw["vmax"]))) if self.norm_type is not None: # this is a convenience mechanism for backward compat, - # allowing to toggle between lin and log scaling without detailled user input + # allowing to toggle between lin and log scaling without detailed user input norm_type = self.norm_type else: if kw["vmin"] == kw["vmax"] or not np.any(np.isfinite(data)): diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index f7a7b5738c..a7bd0d6306 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -169,7 +169,7 @@ def set_log( the field to set a transform if field == 'all', applies to all plots. log : boolean, optional - Log on/off: on means log scaling; off means linear scaling. + set log to True for log scaling, False for linear scaling. linthresh : float, (float, str), unyt_quantity, or 'auto', optional when using symlog scaling, linthresh is the value at which scaling transitions from linear to logarithmic. linthresh must be positive. @@ -271,12 +271,12 @@ def set_norm(self, field, norm: Normalize): r""" Set a custom ``matplotlib.colors.Normalize`` to plot *field*. - Any constraints previously set with `set.log`, `set.zlim` will be + Any constraints previously set with `set_log`, `set_zlim` will be dropped. Note that any float value attached to *norm* (e.g. vmin, vmax, vcenter ...) will be read in the current displayed units, which can be - controlled with the `set_units` method. + controlled with the `set_unit` method. Parameters ---------- From 452f268efb32697927caab3e5a8c64ab9399fe45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 12:15:57 +0200 Subject: [PATCH 24/54] EXP: try simplifying a handler method Co-authored-by: Chris Havlin --- yt/visualization/_handlers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 9ba9c57edf..0a9663ce7d 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -131,8 +131,6 @@ def _set_quan_attr( ) -> None: if newval is None: setattr(self, attr, None) - elif isinstance(newval, Real): - setattr(self, attr, newval * self.display_units) else: try: quan = self.to_quan(newval) From 653185fb96237a9046da2d44380c026215bd4c16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 13:37:45 +0200 Subject: [PATCH 25/54] RFC: drop redundant loops --- yt/visualization/plot_container.py | 44 +++++++++++------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index a7bd0d6306..4acdd51bf5 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -285,15 +285,8 @@ def set_norm(self, field, norm: Normalize): norm : matplotlib.colors.Normalize see https://matplotlib.org/stable/tutorials/colors/colormapnorms.html """ - - if field == "all": - fields = list(self.plots.keys()) - else: - fields = field - - for field in self.data_source._determine_fields(fields): - pnh = self.plots[field].norm_handler - pnh.norm = norm + pnh = self.plots[field].norm_handler + pnh.norm = norm return self @accepts_all_fields @@ -989,27 +982,22 @@ def set_zlim(self, field, zmin=None, zmax=None, dynamic_range=None): zmin = zmax / dynamic_range. """ - - if field == "all": - fields = list(self.plots.keys()) - else: - fields = field if zmin is None and zmax is None: raise TypeError("Missing required argument zmin or zmax") - for field in self.data_source._determine_fields(fields): - if dynamic_range is not None: - if zmax is None and zmin is not None: - zmax = zmin * dynamic_range - elif zmin is None and zmax is not None: - zmin = zmax / dynamic_range - else: - raise TypeError( - "Using dynamic_range requires that either zmin or zmax " - "be specified, but not both." - ) - pnh = self.plots[field].norm_handler - pnh.vmin = zmin - pnh.vmax = zmax + + if dynamic_range is not None: + if zmax is None and zmin is not None: + zmax = zmin * dynamic_range + elif zmin is None and zmax is not None: + zmin = zmax / dynamic_range + else: + raise TypeError( + "Using dynamic_range requires that either zmin or zmax " + "be specified, but not both." + ) + pnh = self.plots[field].norm_handler + pnh.vmin = zmin + pnh.vmax = zmax return self @accepts_all_fields From 9072004c3ef233497cfd77d643b21d0587669e92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 16:25:07 +0200 Subject: [PATCH 26/54] MNT: cleanup dead code --- yt/visualization/plot_container.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 4acdd51bf5..671dde45aa 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import matplotlib -import numpy as np from matplotlib.colors import LogNorm, Normalize, SymLogNorm from matplotlib.font_manager import FontProperties from unyt.dimensions import length @@ -74,24 +73,6 @@ def newfunc(self, field, *args, **kwargs): return newfunc -field_transforms = {} - - -class FieldTransform: - def __init__(self, name, func): - self.name = name - self.func = func - field_transforms[name] = self - - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) - - -log_transform = FieldTransform("log10", np.log10) -linear_transform = FieldTransform("linear", lambda x: x) -symlog_transform = FieldTransform("symlog", None) - - class PlotDictionary(defaultdict): def __getitem__(self, item): return defaultdict.__getitem__( From ea080196847612382f57dd02953f7cf0e1ef883b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 16:53:45 +0200 Subject: [PATCH 27/54] TST: add new tests --- yt/visualization/tests/test_set_zlim.py | 51 +++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 yt/visualization/tests/test_set_zlim.py diff --git a/yt/visualization/tests/test_set_zlim.py b/yt/visualization/tests/test_set_zlim.py new file mode 100644 index 0000000000..13814d3cd5 --- /dev/null +++ b/yt/visualization/tests/test_set_zlim.py @@ -0,0 +1,51 @@ +import numpy as np +import numpy.testing as npt + +from yt.testing import fake_amr_ds +from yt.visualization.api import SlicePlot + + +def test_float_vmin_then_set_unit(): + # this test doesn't represent how users should interact with plot containers + # in particular it uses the `_setup_plots()` private method, as a quick way to + # create a plot without having to make it an answer test + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + raw_lims = np.array((cb.vmin, cb.vmax)) + desired_lims = raw_lims.copy() + desired_lims[0] = 1e-2 + + p.set_zlim(field, zmin=desired_lims[0]) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, desired_lims) + + # 1 g/cm**3 == 1000 kg/m**3 + p.set_unit(field, "kg/m**3") + p._setup_plots() + + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, 1000 * desired_lims) + + +def test_set_unit_then_float_vmin(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p.set_unit(field, "kg/m**3") + p.set_zlim(field, zmin=1) + p._setup_plots() + cb = p.plots[field].image.colorbar + assert cb.vmin == 1.0 From 24c53882ac50b0b9394b4c38d2e18fba96efc967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 17:00:33 +0200 Subject: [PATCH 28/54] DOC: add an important note to set_norm docs --- doc/source/visualizing/plots.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index dedc867658..b43c778e74 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -1040,6 +1040,15 @@ great responsibility. slc.set_cmap(("gas", "velocity_x"), "RdBu") slc.save() +.. note:: When calling + :meth:`~yt.visualization.plot_container.PlotContainer.set_norm`, any constraints + previously set with + :meth:`~yt.visualization.plot_container.PlotContainer.set_log` or + :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` will be dropped. + Conversely, calling ``set_log`` or ``set_zlim`` will have the + effect of dropping any norm previously set via ``set_norm``. + + The :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_background_color` function accepts a field name and a color (optional). If color is given, the function will set the plot's background color to that. If not, it will set it to the bottom From 7af2bbeac98186855e75cbba9b25d5a1cd03791c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 17:17:36 +0200 Subject: [PATCH 29/54] TST: add new test --- yt/visualization/tests/test_set_zlim.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/yt/visualization/tests/test_set_zlim.py b/yt/visualization/tests/test_set_zlim.py index 13814d3cd5..53ef69cf6c 100644 --- a/yt/visualization/tests/test_set_zlim.py +++ b/yt/visualization/tests/test_set_zlim.py @@ -49,3 +49,27 @@ def test_set_unit_then_float_vmin(): p._setup_plots() cb = p.plots[field].image.colorbar assert cb.vmin == 1.0 + + +def test_reset_zlim(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + raw_lims = np.array((cb.vmin, cb.vmax)) + + # set a new zin value + delta = np.diff(raw_lims)[0] + p.set_zlim(field, zmin=raw_lims[0] + delta / 2) + + # passing a None explicitly should restore default limit + p.set_zlim(field, zmin=None) + p._setup_plots() + + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_array_equal(new_lims, raw_lims) From f59c0eddf649b1ecf9e94320a6d7a80000c28020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 17:40:36 +0200 Subject: [PATCH 30/54] BUG: fix issue with resetting zlim by passing None explicitly --- yt/visualization/plot_container.py | 54 ++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 671dde45aa..a0f271b2f1 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -35,9 +35,9 @@ ) if sys.version_info >= (3, 8): - from typing import Literal + from typing import Final, Literal else: - from typing_extensions import Literal + from typing_extensions import Final, Literal latex_prefixes = { "u": r"\mu", @@ -73,6 +73,19 @@ def newfunc(self, field, *args, **kwargs): return newfunc +# define a singleton sentinel to be used as default value distinct from None +class Unset: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = object.__new__(cls) + return cls._instance + + +UNSET: Final = Unset() + + class PlotDictionary(defaultdict): def __getitem__(self, item): return defaultdict.__getitem__( @@ -938,7 +951,13 @@ def set_background_color(self, field, color=None): @accepts_all_fields @invalidate_plot - def set_zlim(self, field, zmin=None, zmax=None, dynamic_range=None): + def set_zlim( + self, + field, + zmin: Union[float, Quantity, Literal["min"], None, Unset] = UNSET, + zmax: Union[float, Quantity, Literal["max"], None, Unset] = UNSET, + dynamic_range=None, + ): """set the scale of the colormap Parameters @@ -946,11 +965,11 @@ def set_zlim(self, field, zmin=None, zmax=None, dynamic_range=None): field : string the field to set a colormap scale if field == 'all', applies to all plots. - zmin : float, tuple, YTQuantity or str - the new minimum of the colormap scale. If 'min', will + zmin : float, Quantity, None or 'min' + the new minimum of the colormap scale. If None or 'min', will set to the minimum value in the current view. - zmax : float, tuple, YTQuantity or str - the new maximum of the colormap scale. If 'max', will + zmax : float, Quantity, None or 'max' + the new maximum of the colormap scale. If None or 'max', will set to the maximum value in the current view. Other Parameters @@ -963,19 +982,32 @@ def set_zlim(self, field, zmin=None, zmax=None, dynamic_range=None): zmin = zmax / dynamic_range. """ - if zmin is None and zmax is None: + if zmin is UNSET and zmax is UNSET: raise TypeError("Missing required argument zmin or zmax") + if zmin == "min": + zmin = None + + if zmax == "max": + zmax = None + if dynamic_range is not None: - if zmax is None and zmin is not None: + if zmax is UNSET and zmin is not UNSET: zmax = zmin * dynamic_range - elif zmin is None and zmax is not None: + elif zmin is UNSET and zmax is not UNSET: zmin = zmax / dynamic_range else: raise TypeError( "Using dynamic_range requires that either zmin or zmax " - "be specified, but not both." + "be explicitly specified, but not both." ) + + if zmin is UNSET: + zmin = None + + if zmax is UNSET: + zmax = None + pnh = self.plots[field].norm_handler pnh.vmin = zmin pnh.vmax = zmax From ed9bfaca202fcc00638c52aee8363ca3d8e3b609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 17:49:35 +0200 Subject: [PATCH 31/54] TYP: fix a couple type hints --- yt/visualization/plot_container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index a0f271b2f1..9133e9e9d1 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -117,7 +117,7 @@ class PlotContainer(abc.ABC): _default_figure_size = tuple(matplotlib.rcParams["figure.figsize"]) _default_font_size = 14.0 - def __init__(self, data_source, figure_size=None, fontsize: float = None): + def __init__(self, data_source, figure_size=None, fontsize: Optional[float] = None): self.data_source = data_source self.ds = data_source.ds self.ts = self._initialize_dataset(self.ds) @@ -956,7 +956,7 @@ def set_zlim( field, zmin: Union[float, Quantity, Literal["min"], None, Unset] = UNSET, zmax: Union[float, Quantity, Literal["max"], None, Unset] = UNSET, - dynamic_range=None, + dynamic_range: Optional[float] = None, ): """set the scale of the colormap From fe3ee5f8494b3849e30d67830670a3a2a7878510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 May 2022 22:24:51 +0200 Subject: [PATCH 32/54] ENH: move dynamic_range setting logic to NormHandler --- yt/visualization/_handlers.py | 114 +++++++++++++++++++++--- yt/visualization/plot_container.py | 25 +----- yt/visualization/tests/test_set_zlim.py | 53 ++++++++++- 3 files changed, 156 insertions(+), 36 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 0a9663ce7d..5f55cdb242 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -1,6 +1,7 @@ +import sys import weakref from numbers import Real -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np import unyt as un @@ -13,6 +14,11 @@ from yt.funcs import get_brewer_cmap, is_sequence, mylog from yt.visualization._commons import MPL_VERSION +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class NormHandler: """ @@ -35,12 +41,19 @@ class NormHandler: "_display_units", "_vmin", "_vmax", + "_dynamic_range", "_norm_type", "_linthresh", "_norm_type", "_norm", ) - _constraint_attrs: List[str] = ["vmin", "vmax", "norm_type", "linthresh"] + _constraint_attrs: List[str] = [ + "vmin", + "vmax", + "dynamic_range", + "norm_type", + "linthresh", + ] def __init__( self, @@ -49,6 +62,7 @@ def __init__( display_units: un.Unit, vmin: Optional[un.unyt_quantity] = None, vmax: Optional[un.unyt_quantity] = None, + dynamic_range: Optional[float] = None, norm_type: Optional[Type[Normalize]] = None, norm: Optional[Normalize] = None, linthresh: Optional[float] = None, @@ -60,6 +74,7 @@ def __init__( self._norm = norm self._vmin = vmin self._vmax = vmax + self._dynamic_range = dynamic_range self._norm_type = norm_type self._linthresh = linthresh @@ -143,22 +158,86 @@ def _set_quan_attr( setattr(self, attr, quan) @property - def vmin(self) -> Optional[un.unyt_quantity]: + def vmin(self) -> Optional[Union[un.unyt_quantity, Literal["min"]]]: return self._vmin @vmin.setter - def vmin(self, newval: Optional[Union[Quantity, float]]) -> None: + def vmin(self, newval: Optional[Union[Quantity, float, Literal["min"]]]) -> None: self._reset_norm() - self._set_quan_attr("_vmin", newval) + if newval == "min": + self._vmin = "min" + else: + self._set_quan_attr("_vmin", newval) @property - def vmax(self) -> Optional[un.unyt_quantity]: + def vmax(self) -> Optional[Union[un.unyt_quantity, Literal["max"]]]: return self._vmax @vmax.setter - def vmax(self, newval: Optional[Union[Quantity, float]]) -> None: + def vmax(self, newval: Optional[Union[Quantity, float, Literal["max"]]]) -> None: + self._reset_norm() + if newval == "max": + self._vmax = "max" + else: + self._set_quan_attr("_vmax", newval) + + @property + def dynamic_range(self) -> Optional[float]: + return self._dynamic_range + + @dynamic_range.setter + def dynamic_range(self, newval: Optional[float]) -> None: + if newval is None: + return + + try: + newval = float(newval) + except TypeError: + raise TypeError( + "Expected a float. " f"Received {newval} with type {type(newval)}" + ) from None + + if newval <= 0: + raise ValueError( + f"Dynamic range must be strictly positive. Received {newval}" + ) + + if newval == 1: + raise ValueError("Dynamic range cannot be unity.") + self._reset_norm() - self._set_quan_attr("_vmax", newval) + self._dynamic_range = newval + + def get_dynamic_range( + self, dvmin: Optional[float], dvmax: Optional[float] + ) -> Tuple[float, float]: + if self.dynamic_range is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + + if self.vmax is None: + if self.vmin is None: + raise TypeError( + "Cannot set dynamic range with neither " + "vmin and vmax being constrained." + ) + if dvmin is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + return dvmin, dvmin * self.dynamic_range + elif self.vmin is None: + if dvmax is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + return dvmax / self.dynamic_range, dvmax + else: + raise TypeError( + "Cannot set dynamic range with both " + "vmin and vmax already constrained." + ) @property def norm_type(self) -> Optional[Type[Normalize]]: @@ -212,20 +291,27 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: if self.has_norm: return self.norm + dvmin = dvmax = None + finite_values_mask = np.isfinite(data) - if self.vmin is not None: + if self.vmin not in (None, "min"): dvmin = self.to_float(self.vmin) elif np.any(finite_values_mask): dvmin = self.to_float(np.nanmin(data[finite_values_mask])) - else: - dvmin = 1 * getattr(data, "units", 1) - kw.setdefault("vmin", dvmin) - if self.vmax is not None: + if self.vmax not in (None, "max"): dvmax = self.to_float(self.vmax) elif np.any(finite_values_mask): dvmax = self.to_float(np.nanmax(data[finite_values_mask])) - else: + + if self.dynamic_range is not None: + dvmin, dvmax = self.get_dynamic_range(dvmin, dvmax) + + if dvmin is None: + dvmin = 1 * getattr(data, "units", 1) + kw.setdefault("vmin", dvmin) + + if dvmax is None: dvmax = 1 * getattr(data, "units", 1) kw.setdefault("vmax", dvmax) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 9133e9e9d1..41caecff26 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -954,8 +954,8 @@ def set_background_color(self, field, color=None): def set_zlim( self, field, - zmin: Union[float, Quantity, Literal["min"], None, Unset] = UNSET, - zmax: Union[float, Quantity, Literal["max"], None, Unset] = UNSET, + zmin: Union[float, Quantity, Literal[None, "min"], Unset] = UNSET, + zmax: Union[float, Quantity, Literal[None, "max"], Unset] = UNSET, dynamic_range: Optional[float] = None, ): """set the scale of the colormap @@ -978,30 +978,11 @@ def set_zlim( The dynamic range of the image. If zmin == None, will set zmin = zmax / dynamic_range If zmax == None, will set zmax = zmin * dynamic_range - When dynamic_range is specified, defaults to setting - zmin = zmax / dynamic_range. """ if zmin is UNSET and zmax is UNSET: raise TypeError("Missing required argument zmin or zmax") - if zmin == "min": - zmin = None - - if zmax == "max": - zmax = None - - if dynamic_range is not None: - if zmax is UNSET and zmin is not UNSET: - zmax = zmin * dynamic_range - elif zmin is UNSET and zmax is not UNSET: - zmin = zmax / dynamic_range - else: - raise TypeError( - "Using dynamic_range requires that either zmin or zmax " - "be explicitly specified, but not both." - ) - if zmin is UNSET: zmin = None @@ -1011,6 +992,8 @@ def set_zlim( pnh = self.plots[field].norm_handler pnh.vmin = zmin pnh.vmax = zmax + pnh.dynamic_range = dynamic_range + return self @accepts_all_fields diff --git a/yt/visualization/tests/test_set_zlim.py b/yt/visualization/tests/test_set_zlim.py index 53ef69cf6c..4b3d69b491 100644 --- a/yt/visualization/tests/test_set_zlim.py +++ b/yt/visualization/tests/test_set_zlim.py @@ -62,7 +62,7 @@ def test_reset_zlim(): cb = p.plots[field].image.colorbar raw_lims = np.array((cb.vmin, cb.vmax)) - # set a new zin value + # set a new zmin value delta = np.diff(raw_lims)[0] p.set_zlim(field, zmin=raw_lims[0] + delta / 2) @@ -73,3 +73,54 @@ def test_reset_zlim(): cb = p.plots[field].image.colorbar new_lims = np.array((cb.vmin, cb.vmax)) npt.assert_array_equal(new_lims, raw_lims) + + +def test_set_dynamic_range_with_vmin(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + zmin = 1e-2 + p.set_zlim(field, zmin=zmin, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (zmin, 2 * zmin)) + + +def test_set_dynamic_range_with_vmax(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + zmax = 1 + p.set_zlim(field, zmax=zmax, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (zmax / 2, zmax)) + + +def test_set_dynamic_range_with_min(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + vmin = cb.vmin + + p.set_zlim(field, zmin="min", dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (vmin, 2 * vmin)) From bc16c0aac385e20d3e94eac0d7530e3399129574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 08:03:18 +0200 Subject: [PATCH 33/54] RFC: refactor comparison to avoid a deprecation warning in minimal supported env --- yt/visualization/_handlers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 5f55cdb242..42d63463fc 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -294,12 +294,20 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: dvmin = dvmax = None finite_values_mask = np.isfinite(data) - if self.vmin not in (None, "min"): + # FUTURE: when the minimal supported version of numpy reaches 1.16 or newer, + # this complicated conditional can be simplified into + # if self.vmin not in (None, "min"): + if self.vmin is not None and not ( + isinstance(self.vmin, str) and self.vmin == "min" + ): dvmin = self.to_float(self.vmin) elif np.any(finite_values_mask): dvmin = self.to_float(np.nanmin(data[finite_values_mask])) - if self.vmax not in (None, "max"): + # FUTURE: see above + if self.vmax is not None and not ( + isinstance(self.vmax, str) and self.vmax == "max" + ): dvmax = self.to_float(self.vmax) elif np.any(finite_values_mask): dvmax = self.to_float(np.nanmax(data[finite_values_mask])) From b4a4f9acf70751f6e20fb62920f456ab82294a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 08:14:50 +0200 Subject: [PATCH 34/54] cleanup --- yt/visualization/_handlers.py | 8 +++----- yt/visualization/plot_container.py | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 42d63463fc..d6079a7a01 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -194,13 +194,11 @@ def dynamic_range(self, newval: Optional[float]) -> None: newval = float(newval) except TypeError: raise TypeError( - "Expected a float. " f"Received {newval} with type {type(newval)}" + f"Expected a float. Received {newval} with type {type(newval)}" ) from None - if newval <= 0: - raise ValueError( - f"Dynamic range must be strictly positive. Received {newval}" - ) + if newval == 0: + raise ValueError("Dynamic range cannot be zero.") if newval == 1: raise ValueError("Dynamic range cannot be unity.") diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 41caecff26..23dba57430 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -552,7 +552,6 @@ def save( if "Cutting" in self.data_source.__class__.__name__: plot_type = "OffAxisSlice" - # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here for k, v in self.plots.items(): if isinstance(k, tuple): k = k[1] From 8343dabe41957c239aee0725d38c614854e2cc19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 11:05:35 +0200 Subject: [PATCH 35/54] DEPR: deprecate passing zmin=None explicitly, add a test to check that it still works and raises a deprecation warning --- nose_unit.cfg | 2 +- tests/tests.yaml | 1 + yt/visualization/plot_container.py | 37 +++++++++++++++++++++---- yt/visualization/tests/test_set_zlim.py | 28 +++++++++++++++++-- 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/nose_unit.cfg b/nose_unit.cfg index 8256e7baf4..dd30458291 100644 --- a/nose_unit.cfg +++ b/nose_unit.cfg @@ -6,5 +6,5 @@ nologcapture=1 verbosity=2 where=yt with-timer=1 -ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py) +ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|test_set_zlim\.py) exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF diff --git a/tests/tests.yaml b/tests/tests.yaml index f2f81ef4a6..2e854d06e0 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -214,6 +214,7 @@ other_tests: - "--ignore-files=test_normal_plot_api\\.py" - "--ignore-file=test_file_sanitizer\\.py" - "--ignore-files=test_version\\.py" + - "--ignore-files=test_set_zlim\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles" diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 23dba57430..6d12ab0920 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -953,8 +953,8 @@ def set_background_color(self, field, color=None): def set_zlim( self, field, - zmin: Union[float, Quantity, Literal[None, "min"], Unset] = UNSET, - zmax: Union[float, Quantity, Literal[None, "max"], Unset] = UNSET, + zmin: Union[float, Quantity, Literal["min"], Unset] = UNSET, + zmax: Union[float, Quantity, Literal["max"], Unset] = UNSET, dynamic_range: Optional[float] = None, ): """set the scale of the colormap @@ -964,11 +964,11 @@ def set_zlim( field : string the field to set a colormap scale if field == 'all', applies to all plots. - zmin : float, Quantity, None or 'min' - the new minimum of the colormap scale. If None or 'min', will + zmin : float, Quantity, or 'min' + the new minimum of the colormap scale. If 'min', will set to the minimum value in the current view. - zmax : float, Quantity, None or 'max' - the new maximum of the colormap scale. If None or 'max', will + zmax : float, Quantity, or 'max' + the new maximum of the colormap scale. If 'max', will set to the maximum value in the current view. Other Parameters @@ -984,9 +984,34 @@ def set_zlim( if zmin is UNSET: zmin = None + elif zmin is None: + # this sentinel value juggling is barely maintainable + # this use case is deprecated so we can simplify the logic here + # in the future and use `None` as the default value, + # instead of the custom sentinel UNSET + issue_deprecation_warning( + "Passing `zmin=None` explicitly is deprecated. " + "If you wish to explicitly set zmin to the minimal " + "data value, pass `zmin='min'` instead. " + "Otherwise leave this argument unset.", + since="4.1.0", + stacklevel=5, + ) + zmin = "min" if zmax is UNSET: zmax = None + elif zmax is None: + # see above + issue_deprecation_warning( + "Passing `zmax=None` explicitly is deprecated. " + "If you wish to explicitly set zmax to the maximal " + "data value, pass `zmin='max'` instead. " + "Otherwise leave this argument unset.", + since="4.1.0", + stacklevel=5, + ) + zmax = "max" pnh = self.plots[field].norm_handler pnh.vmin = zmin diff --git a/yt/visualization/tests/test_set_zlim.py b/yt/visualization/tests/test_set_zlim.py index 4b3d69b491..00d41c0211 100644 --- a/yt/visualization/tests/test_set_zlim.py +++ b/yt/visualization/tests/test_set_zlim.py @@ -1,6 +1,8 @@ import numpy as np import numpy.testing as npt +import pytest +from yt._maintenance.deprecation import VisibleDeprecationWarning from yt.testing import fake_amr_ds from yt.visualization.api import SlicePlot @@ -66,8 +68,8 @@ def test_reset_zlim(): delta = np.diff(raw_lims)[0] p.set_zlim(field, zmin=raw_lims[0] + delta / 2) - # passing a None explicitly should restore default limit - p.set_zlim(field, zmin=None) + # passing "min" should restore default limit + p.set_zlim(field, zmin="min") p._setup_plots() cb = p.plots[field].image.colorbar @@ -124,3 +126,25 @@ def test_set_dynamic_range_with_min(): cb = p.plots[field].image.colorbar new_lims = np.array((cb.vmin, cb.vmax)) npt.assert_almost_equal(new_lims, (vmin, 2 * vmin)) + + +def test_set_dynamic_range_with_None(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + vmin = cb.vmin + + with pytest.raises( + VisibleDeprecationWarning, match="Passing `zmin=None` explicitly is deprecated" + ): + p.set_zlim(field, zmin=None, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (vmin, 2 * vmin)) From 2dbd3422eaf8d8e205cd9ebe8c4595974f46fb4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 11:29:56 +0200 Subject: [PATCH 36/54] DOC: update colorbar norms docs --- doc/source/visualizing/plots.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index b43c778e74..ba6bb373a6 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -931,6 +931,8 @@ strictly positive, and symlog otherwise. yt supports two different interfaces to move away from the defaults. See **constrained norms** and **arbitrary norm** hereafter. +.. note:: defaults can be configured on a per-field basis, see :ref:`per-field-plotconfig` + **Constrained norms** The norm properties can be constrained via two methods @@ -957,8 +959,11 @@ Units can be left out, in which case they implicitly match the current display units of the colorbar (controlled with the ``set_unit`` method, see :ref:`_set-image-units`). -Both ``zmin`` and ``zmax`` are optional, but note that they respectively default -to 0 and 1, which can be widely inappropriate, so it is recommended to specify both. +It is not required to specify both ``zmin`` and ``zmax``. Left unset, they will +default to extremal values in the current view. This default beheviour can be +enforced or restored by passing ``zmin="min"`` (reps. ``zmax="max"``) +explicitly. + :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` takes a boolean argument to select log (``True``) or linear (``False``) scalings. From 24c84433fdcc168cc6f942414501d24ada9ff106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 11:48:30 +0200 Subject: [PATCH 37/54] TYP: renounce typing constraints for a very flexible attribute --- yt/visualization/_handlers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index d6079a7a01..d9e5e7241e 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -437,11 +437,14 @@ def cmap(self, newval) -> None: ) @property - def background_color(self) -> str: + def background_color(self) -> Any: return self._background_color or "white" @background_color.setter - def background_color(self, newval): + def background_color(self, newval: Any): + # not attempting to constrain types here because + # down the line it really depends on matplotlib.axes.Axes.set_faceolor + # which is very type-flexibile if newval is None: self._background_color = self.cmap(0) else: From b78b055beb6a25ed995f49c20d0f6a7fe903a7c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 12:26:29 +0200 Subject: [PATCH 38/54] TST: simplify some flaky tests --- tests/tests.yaml | 2 -- .../tests/test_norm_api_custom_norm.py | 15 ++------ .../tests/test_norm_api_set_unit_and_zlim.py | 36 ++++++++++--------- .../tests/test_norm_api_set_zlim_and_unit.py | 22 ------------ 4 files changed, 23 insertions(+), 52 deletions(-) delete mode 100644 yt/visualization/tests/test_norm_api_set_zlim_and_unit.py diff --git a/tests/tests.yaml b/tests/tests.yaml index 2e854d06e0..f2e4206e18 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -187,8 +187,6 @@ answer_tests: - yt/visualization/tests/test_norm_api_lineplot.py:test_lineplot_set_axis_properties - yt/visualization/tests/test_norm_api_profileplot.py:test_profileplot_set_axis_properties - yt/visualization/tests/test_norm_api_custom_norm.py:test_sliceplot_custom_norm - - yt/visualization/tests/test_norm_api_set_zlim_and_unit.py:test_sliceplot_set_zlim_and_unit - - yt/visualization/tests/test_norm_api_set_unit_and_zlim.py:test_sliceplot_set_unit_and_zlim - yt/visualization/tests/test_norm_api_set_background_color.py:test_sliceplot_set_background_color - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py:test_phaseplot_set_colorbar_properties_implicit - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py:test_phaseplot_set_colorbar_properties_explicit diff --git a/yt/visualization/tests/test_norm_api_custom_norm.py b/yt/visualization/tests/test_norm_api_custom_norm.py index 67ccca6f1a..fc4badebd0 100644 --- a/yt/visualization/tests/test_norm_api_custom_norm.py +++ b/yt/visualization/tests/test_norm_api_custom_norm.py @@ -1,5 +1,4 @@ import matplotlib -from matplotlib.colors import LogNorm, Normalize, SymLogNorm from nose.plugins.attrib import attr from packaging.version import Version @@ -20,21 +19,13 @@ def test_sliceplot_custom_norm(): # don't import this at top level because it's only available since MPL 3.2 from matplotlib.colors import TwoSlopeNorm - norms_to_test = [ - (Normalize(), "linear"), - (LogNorm(), "log"), - (TwoSlopeNorm(vcenter=0, vmin=-0.5, vmax=1), "twoslope"), - (SymLogNorm(linthresh=0.01, vmin=-1, vmax=1), "symlog"), - ] - ds = fake_random_ds(16) def create_image(filename_prefix): field = ("gas", "density") - for norm, name in norms_to_test: - p = SlicePlot(ds, "z", field) - p.set_norm(field, norm=norm) - p.save(f"{filename_prefix}_{name}") + p = SlicePlot(ds, "z", field) + p.set_norm(field, norm=(TwoSlopeNorm(vcenter=0, vmin=-0.5, vmax=1))) + p.save(f"{filename_prefix}") test = GenericImageTest(ds, create_image, 12) test.prefix = "test_sliceplot_custom_norm" diff --git a/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py index 1e78fe75be..250e0f05b4 100644 --- a/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py +++ b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py @@ -1,22 +1,26 @@ -from nose.plugins.attrib import attr +import numpy.testing as npt -from yt.testing import ANSWER_TEST_TAG, fake_random_ds -from yt.utilities.answer_testing.framework import GenericImageTest +from yt.testing import fake_random_ds from yt.visualization.api import SlicePlot -@attr(ANSWER_TEST_TAG) -def test_sliceplot_set_unit_and_zlim(): +def test_sliceplot_set_unit_and_zlim_order(): ds = fake_random_ds(16) + field = ("gas", "density") - def create_image(filename_prefix): - field = ("gas", "density") - p = SlicePlot(ds, "z", field) - p.set_unit(field, "kg/m**3") - p.set_zlim(field, zmin=0) - p.save(filename_prefix) - - test = GenericImageTest(ds, create_image, 12) - test.prefix = "test_sliceplot_set_unit_and_zlim" - test.answer_name = "sliceplot_set_unit_and_zlim" - yield test + p0 = SlicePlot(ds, "z", field) + p0.set_unit(field, "kg/m**3") + p0.set_zlim(field, zmin=0) + + # reversing order of operations + p1 = SlicePlot(ds, "z", field) + p1.set_zlim(field, zmin=0) + p1.set_unit(field, "kg/m**3") + + p0._setup_plots() + p1._setup_plots() + + im0 = p0.plots[field].image._A + im1 = p1.plots[field].image._A + + npt.assert_allclose(im0, im1) diff --git a/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py b/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py deleted file mode 100644 index 41935974b1..0000000000 --- a/yt/visualization/tests/test_norm_api_set_zlim_and_unit.py +++ /dev/null @@ -1,22 +0,0 @@ -from nose.plugins.attrib import attr - -from yt.testing import ANSWER_TEST_TAG, fake_random_ds -from yt.utilities.answer_testing.framework import GenericImageTest -from yt.visualization.api import SlicePlot - -ds = fake_random_ds(16) - - -@attr(ANSWER_TEST_TAG) -def test_sliceplot_set_zlim_and_unit(): - def create_image(filename_prefix): - field = ("gas", "density") - p = SlicePlot(ds, "z", field) - p.set_zlim(field, zmin=0) - p.set_unit(field, "kg/m**3") - p.save(filename_prefix) - - test = GenericImageTest(ds, create_image, 12) - test.prefix = "test_sliceplot_set_zlim_and_unit" - test.answer_name = "sliceplot_set_zlim_and_unit" - yield test From 26be71c4407323c5df8a06bdb8186f67cb057f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 13:01:28 +0200 Subject: [PATCH 39/54] bump answers store --- answer-store | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/answer-store b/answer-store index d9c1557f7d..209065ac2e 160000 --- a/answer-store +++ b/answer-store @@ -1 +1 @@ -Subproject commit d9c1557f7dba849a4446c11e2239b1822d1b7fde +Subproject commit 209065ac2e6ff1d11dd108c8cd268363caa9cc9f From a86464ac25b82f9f6000dad6c5b0e222c40f8988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 13:43:21 +0200 Subject: [PATCH 40/54] TST: fix broken answer tests (maybe bumping will still be required) --- yt/visualization/tests/test_particle_plot.py | 2 +- yt/visualization/tests/test_plotwindow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/yt/visualization/tests/test_particle_plot.py b/yt/visualization/tests/test_particle_plot.py index d5e060cb52..d0037e48fd 100644 --- a/yt/visualization/tests/test_particle_plot.py +++ b/yt/visualization/tests/test_particle_plot.py @@ -42,7 +42,7 @@ def setup(): PROJ_ATTR_ARGS["set_log"] = [((("all", "particle_mass"), False), {})] PROJ_ATTR_ARGS["set_zlim"] = [ ((("all", "particle_mass"), 1e39, 1e42), {}), - ((("all", "particle_mass"), 1e39, None), {"dynamic_range": 4}), + ((("all", "particle_mass"),), {"zmin": 1e39, "dynamic_range": 4}), ] PHASE_ATTR_ARGS = { diff --git a/yt/visualization/tests/test_plotwindow.py b/yt/visualization/tests/test_plotwindow.py index bf19692c86..61a4539507 100644 --- a/yt/visualization/tests/test_plotwindow.py +++ b/yt/visualization/tests/test_plotwindow.py @@ -72,7 +72,7 @@ def setup(): "set_figure_size": [((7.0,), {})], "set_zlim": [ (("density", 1e-25, 1e-23), {}), - (("density", 1e-25, None), {"dynamic_range": 4}), + (("density",), {"zmin": 1e-25, "dynamic_range": 4}), ], "zoom": [((10,), {})], "toggle_right_handed": [((), {})], From d9a8efa5f744298677ac0c53fb28674018cec801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 15:01:35 +0200 Subject: [PATCH 41/54] bump answers again --- tests/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests.yaml b/tests/tests.yaml index f2e4206e18..155271acda 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -95,7 +95,7 @@ answer_tests: - yt/frontends/owls/tests/test_outputs.py:test_snapshot_033 - yt/frontends/owls/tests/test_outputs.py:test_OWLS_particlefilter - local_pw_045: # PR 3849 + local_pw_046: # PR 3849 - yt/visualization/tests/test_plotwindow.py:test_attributes - yt/visualization/tests/test_particle_plot.py:test_particle_projection_answers - yt/visualization/tests/test_particle_plot.py:test_particle_projection_filter @@ -183,7 +183,7 @@ answer_tests: local_nc4_cm1_001: # PR 2176 - yt/frontends/nc4_cm1/tests/test_outputs.py:test_cm1_mesh_fields - local_norm_api_007: # PR 3849 + local_norm_api_008: # PR 3849 - yt/visualization/tests/test_norm_api_lineplot.py:test_lineplot_set_axis_properties - yt/visualization/tests/test_norm_api_profileplot.py:test_profileplot_set_axis_properties - yt/visualization/tests/test_norm_api_custom_norm.py:test_sliceplot_custom_norm From 50a8959135472e7a27764d320f5f43691bade49d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 18:09:21 +0200 Subject: [PATCH 42/54] cleanup test --- yt/visualization/tests/test_norm_api_custom_norm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/yt/visualization/tests/test_norm_api_custom_norm.py b/yt/visualization/tests/test_norm_api_custom_norm.py index fc4badebd0..14e28581a9 100644 --- a/yt/visualization/tests/test_norm_api_custom_norm.py +++ b/yt/visualization/tests/test_norm_api_custom_norm.py @@ -2,21 +2,19 @@ from nose.plugins.attrib import attr from packaging.version import Version -from yt.testing import ANSWER_TEST_TAG, fake_random_ds, skip_case +from yt.testing import ANSWER_TEST_TAG, fake_random_ds, skipif from yt.utilities.answer_testing.framework import GenericImageTest from yt.visualization.api import SlicePlot MPL_VERSION = Version(matplotlib.__version__) +@skipif( + MPL_VERSION < Version("3.2"), + reason=f"TwoSlopeNorm requires MPL 3.2, we have {MPL_VERSION}", +) @attr(ANSWER_TEST_TAG) def test_sliceplot_custom_norm(): - if MPL_VERSION < Version("3.4"): - skip_case( - reason="in MPL<3.4, SymLogNorm emits a deprecation warning " - "that cannot be easily filtered" - ) - # don't import this at top level because it's only available since MPL 3.2 from matplotlib.colors import TwoSlopeNorm ds = fake_random_ds(16) From 9065f04ba6b40d7281599b6e887337bbfe5ad627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 20:36:44 +0200 Subject: [PATCH 43/54] DEPR: add missing stacklevel argument to deprecation warning --- yt/visualization/plot_container.py | 1 + 1 file changed, 1 insertion(+) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 6d12ab0920..2a8b965310 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -178,6 +178,7 @@ def set_log( issue_deprecation_warning( "the symlog_auto argument is deprecated. Use linthresh='auto' instead", since="4.1", + stacklevel=5, ) if symlog_auto is True: linthresh = "auto" From c4077bbc1e8dee0f64d9438a22d20348fc7eb4e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 20:36:55 +0200 Subject: [PATCH 44/54] UX: add a UserWarning --- yt/visualization/plot_container.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 2a8b965310..bb77d21896 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -190,6 +190,13 @@ def set_log( f"Expected a boolean, got {symlog_auto!r}" ) + if log is not None and linthresh is not None: + # we do not raise an error here for backward compatibility + warnings.warn( + f"Passing log={log} has no effect when linthresh is also specified.", + stacklevel=4, + ) + pnh = self.plots[field].norm_handler if linthresh is not None: From a70c455111ccfb8bf7563882acb7c04387d1788a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 19 May 2022 22:38:00 +0200 Subject: [PATCH 45/54] TST: update a test to use newly encouraged API and avoid a UserWarning --- yt/visualization/tests/test_plotwindow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yt/visualization/tests/test_plotwindow.py b/yt/visualization/tests/test_plotwindow.py index 61a4539507..0d3d6d1929 100644 --- a/yt/visualization/tests/test_plotwindow.py +++ b/yt/visualization/tests/test_plotwindow.py @@ -737,7 +737,7 @@ def _neg_density(field, data): ("gas", "negative_density"), ]: plot = SlicePlot(ds, 2, field) - plot.set_log(field, True, linthresh=0.1) + plot.set_log(field, linthresh=0.1) with tempfile.NamedTemporaryFile(suffix="png") as f: plot.save(f.name) From 1a20ce322619434a37fb842591db4099ba8931ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 8 Jun 2022 00:07:22 +0200 Subject: [PATCH 46/54] BUG: fix automated field unit detection for profiles --- yt/data_objects/selection_objects/data_selection_objects.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py index a4cc4fe612..0a4081882e 100644 --- a/yt/data_objects/selection_objects/data_selection_objects.py +++ b/yt/data_objects/selection_objects/data_selection_objects.py @@ -262,6 +262,8 @@ def _generate_fields(self, fields_to_generate): fi.units = sunits fi.dimensions = dimensions self.field_data[field] = self.ds.arr(fd, units) + if fi.output_units is None: + fi.output_units = fi.units try: fd.convert_to_units(fi.units) From f81a6ef543c04e79cee0e0f8e14f73274f5965b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 10 Jun 2022 21:39:18 +0200 Subject: [PATCH 47/54] BUG: fix erroneous names breaking p.set_colorbar_minorticks methods --- yt/visualization/_handlers.py | 2 +- yt/visualization/plot_container.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index d9e5e7241e..3d09f6dc33 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -415,7 +415,7 @@ def draw_minorticks(self, newval) -> None: raise TypeError( f"Excpected a boolean, got {newval} with type {type(newval)}" ) - self._draw_minoticks = newval + self._draw_minorticks = newval @property def cmap(self) -> Colormap: diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index bb77d21896..4dba9aee03 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -1044,7 +1044,7 @@ def set_colorbar_minorticks(self, field, state): state : bool the state indicating 'on' (True) or 'off' (False) """ - self.plots[field].colormap_handler.draw_minorticks = state + self.plots[field].colorbar_handler.draw_minorticks = state return self @invalidate_plot From bcdbdbefd7c39009e8265cd28e670cecc41bbf1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 17 Jun 2022 19:43:25 +0200 Subject: [PATCH 48/54] BUG: cleanup merge conflict leftover --- yt/visualization/plot_window.py | 1 - 1 file changed, 1 deletion(-) diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 21f51ed31f..7c3a759c8e 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -240,7 +240,6 @@ def __init__( self._projection = get_mpl_transform(projection) self._transform = get_mpl_transform(transform) - self.setup_callbacks() self._setup_plots() for field in self.data_source._determine_fields(self.fields): From a72c733e9e2f1f10ea3f7db72bd4e085ee28e07f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 24 Jun 2022 17:14:50 +0200 Subject: [PATCH 49/54] cleanup redundant computation Co-authored-by: Chris Havlin --- yt/visualization/_handlers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py index 3d09f6dc33..8ec3d1ec77 100644 --- a/yt/visualization/_handlers.py +++ b/yt/visualization/_handlers.py @@ -321,7 +321,6 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: dvmax = 1 * getattr(data, "units", 1) kw.setdefault("vmax", dvmax) - min_abs_val, max_abs_val = np.sort(np.abs((kw["vmin"], kw["vmax"]))) if self.norm_type is not None: # this is a convenience mechanism for backward compat, # allowing to toggle between lin and log scaling without detailed user input From e7b17cdc7e11312d38c77039f961340f7ad73506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 5 Aug 2022 13:48:16 +0200 Subject: [PATCH 50/54] cleanup unused imports (cleanup merge conflict resolution from main) --- yt/visualization/plot_window.py | 1 - yt/visualization/profile_plotter.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 5f056eace1..95b7c0068b 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -7,7 +7,6 @@ import numpy as np from matplotlib.colors import Normalize from more_itertools import always_iterable -from packaging.version import Version from unyt.exceptions import UnitConversionError from yt._maintenance.deprecation import issue_deprecation_warning diff --git a/yt/visualization/profile_plotter.py b/yt/visualization/profile_plotter.py index 7120f3694f..17855e23d6 100644 --- a/yt/visualization/profile_plotter.py +++ b/yt/visualization/profile_plotter.py @@ -14,13 +14,10 @@ from yt.funcs import iter_fields, matplotlib_style_context from yt.utilities.exceptions import YTNotInsideNotebook from yt.visualization._handlers import ColorbarHandler, NormHandler -from yt.visualization.base_plot_types import PlotMPL +from yt.visualization.base_plot_types import ImagePlotMPL, PlotMPL from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer from ._commons import validate_image_name -from .base_plot_types import ImagePlotMPL -from ._commons import DEFAULT_FONT_PROPERTIES, MPL_VERSION, validate_image_name -from .base_plot_types import ImagePlotMPL, PlotMPL from .plot_container import ( BaseLinePlot, ImagePlotContainer, From f3d82d7180969e3ad212feaea9a6d2bbc37a6911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 21 Aug 2022 08:30:52 +0200 Subject: [PATCH 51/54] Apply suggestions from code review Co-authored-by: Cameron Hummels --- doc/source/visualizing/plots.rst | 9 +++++---- yt/visualization/plot_container.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index 65af1d0d4d..e782faa764 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -945,10 +945,11 @@ hereafter. **Constrained norms** -The norm properties can be constrained via two methods +The standard way to change colorbar scalings between linear, log, and symmetric +log (symlog). Colorbar properties can be constrained via two methods: -- :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` controls the extrema - of the value range ``zmin`` and ``zmax``. +- :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` controls the limits + of the colorbar range: ``zmin`` and ``zmax``. - :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` allows switching to linear or symlog norms. With symlog, the linear threshold can be set explicitly. Otherwise, yt will dynamically determine a reasonable value. @@ -970,7 +971,7 @@ units of the colorbar (controlled with the ``set_unit`` method, see :ref:`_set-image-units`). It is not required to specify both ``zmin`` and ``zmax``. Left unset, they will -default to extremal values in the current view. This default beheviour can be +default to the extreme values in the current view. This default behavior can be enforced or restored by passing ``zmin="min"`` (reps. ``zmax="max"``) explicitly. diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 9c834a884b..45010333c6 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -200,7 +200,7 @@ def set_log( if log is not None and linthresh is not None: # we do not raise an error here for backward compatibility warnings.warn( - f"Passing log={log} has no effect when linthresh is also specified.", + f"log={log} has no effect because linthresh specified. Using symlog.", stacklevel=4, ) From 49ffec3138e74cf55283a53c0a69c90e097ea8f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 21 Aug 2022 09:13:13 +0200 Subject: [PATCH 52/54] address review comments --- doc/source/visualizing/plots.rst | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index e782faa764..2c80b4c887 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -933,13 +933,22 @@ Use any of the colormaps listed in the :ref:`colormaps` section. slc.set_cmap(("gas", "density"), "RdBu_r") slc.save() -Colorbar norms -:::::::::::::: - -Slice plots and similar plot classes default to log norms when all values are -strictly positive, and symlog otherwise. yt supports two different interfaces to -move away from the defaults. See **constrained norms** and **arbitrary norm** -hereafter. +Colorbar Normalization / Scaling +:::::::::::::::::::::::::::::::: + +For a general introduction to the topic of colorbar scaling, see +``_. Here we +will focus on the defaults, and the ways to customize them, of yt plot classes. +In this section, "norm" is used as short for "normalization", and is +interchangeable with "scaling". + +Map-like plots e.g., ``SlicePlot``, ``ProjectionPlot`` and ``PhasePlot``, +default to `logarithmic (log) +`_ +normalization when all values are strictly positive, and `symetric log (symlog) +`_ +otherwise. yt supports two different interfaces to move away from the defaults. +See **constrained norms** and **arbitrary norm** hereafter. .. note:: defaults can be configured on a per-field basis, see :ref:`per-field-plotconfig` @@ -1034,12 +1043,13 @@ can be left out in ``linthresh``. **Arbitrary norms** -Alternatively, arbitrary matplotlib norms can be passed via the -:meth:`~yt.visualization.plot_container.PlotContainer.set_norm` method. In that -case, any numeric value is treated as having implicit units, matching the -current display units. This alternative interface is more flexible, but -considered experimental as of yt 4.1. Don't forget that with great power comes -great responsibility. +Alternatively, arbitrary `matplotlib norms +`_ can be +passed via the :meth:`~yt.visualization.plot_container.PlotContainer.set_norm` +method. In that case, any numeric value is treated as having implicit units, +matching the current display units. This alternative interface is more flexible, +but considered experimental as of yt 4.1. Don't forget that with great power +comes great responsibility. .. python-script:: From 1c989ffd6f53a03f807cce89ce97fff040a23294 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Mon, 22 Aug 2022 07:36:44 +0200 Subject: [PATCH 53/54] Apply suggestions from code review Co-authored-by: Cameron Hummels --- doc/source/visualizing/plots.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index 2c80b4c887..80e8cb2a33 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -945,7 +945,7 @@ interchangeable with "scaling". Map-like plots e.g., ``SlicePlot``, ``ProjectionPlot`` and ``PhasePlot``, default to `logarithmic (log) `_ -normalization when all values are strictly positive, and `symetric log (symlog) +normalization when all values are strictly positive, and `symmetric log (symlog) `_ otherwise. yt supports two different interfaces to move away from the defaults. See **constrained norms** and **arbitrary norm** hereafter. @@ -958,9 +958,9 @@ The standard way to change colorbar scalings between linear, log, and symmetric log (symlog). Colorbar properties can be constrained via two methods: - :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` controls the limits - of the colorbar range: ``zmin`` and ``zmax``. + of the colorbar range: ``zmin`` and ``zmax``. - :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` allows switching to - linear or symlog norms. With symlog, the linear threshold can be set + linear or symlog normalization. With symlog, the linear threshold can be set explicitly. Otherwise, yt will dynamically determine a reasonable value. Use the :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` From 8f65ecf25c98315fe98fc8982c46022eb3ba4741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Mon, 22 Aug 2022 07:41:51 +0200 Subject: [PATCH 54/54] switch answer store to long-lived branch (no actual changes in answers) --- answer-store | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/answer-store b/answer-store index 209065ac2e..4b440269d2 160000 --- a/answer-store +++ b/answer-store @@ -1 +1 @@ -Subproject commit 209065ac2e6ff1d11dd108c8cd268363caa9cc9f +Subproject commit 4b440269d2e9bd0d9aaf0e8bb523dd9ded3ecafa