diff --git a/answer-store b/answer-store index 997baaad97..4b440269d2 160000 --- a/answer-store +++ b/answer-store @@ -1 +1 @@ -Subproject commit 997baaad97a69b04226e4e1a31171860eb38b491 +Subproject commit 4b440269d2e9bd0d9aaf0e8bb523dd9ded3ecafa diff --git a/conftest.py b/conftest.py index b68ec17b83..71cf5f628d 100644 --- a/conftest.py +++ b/conftest.py @@ -114,15 +114,6 @@ def pytest_configure(config): ): config.addinivalue_line("filterwarnings", value) - if MPL_VERSION < Version("3.0.0"): - config.addinivalue_line( - "filterwarnings", - ( - "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' " - "is deprecated since Python 3.3,and in 3.9 it will stop working:DeprecationWarning" - ), - ) - if MPL_VERSION < Version("3.5.2") and PILLOW_VERSION >= Version("9.1"): # see https://github.com/matplotlib/matplotlib/pull/22766 config.addinivalue_line( diff --git a/doc/source/visualizing/plots.rst b/doc/source/visualizing/plots.rst index 00379c8214..80e8cb2a33 100644 --- a/doc/source/visualizing/plots.rst +++ b/doc/source/visualizing/plots.rst @@ -732,6 +732,9 @@ the axes unit labels. The same result could have been accomplished by explicitly setting the ``width`` to ``(.01, 'Mpc')``. + +.. _set-image-units: + Set image units ~~~~~~~~~~~~~~~ @@ -913,7 +916,7 @@ customization. Colormaps ~~~~~~~~~ -Each of these functions accept two arguments. In all cases the first argument +Each of these functions accepts at least two arguments. In all cases the first argument is a field name. This makes it possible to use different custom colormaps for different fields tracked by the plot object. @@ -930,10 +933,38 @@ Use any of the colormaps listed in the :ref:`colormaps` section. slc.set_cmap(("gas", "density"), "RdBu_r") slc.save() -The :meth:`~yt.visualization.plot_window.AxisAlignedSlicePlot.set_log` function -accepts a field name and a boolean. If the boolean is ``True``, the colormap -for the field will be log scaled. If it is ``False`` the colormap will be -linear. +Colorbar Normalization / Scaling +:::::::::::::::::::::::::::::::: + +For a general introduction to the topic of colorbar scaling, see +``_. Here we +will focus on the defaults, and the ways to customize them, of yt plot classes. +In this section, "norm" is used as short for "normalization", and is +interchangeable with "scaling". + +Map-like plots e.g., ``SlicePlot``, ``ProjectionPlot`` and ``PhasePlot``, +default to `logarithmic (log) +`_ +normalization when all values are strictly positive, and `symmetric log (symlog) +`_ +otherwise. yt supports two different interfaces to move away from the defaults. +See **constrained norms** and **arbitrary norm** hereafter. + +.. note:: defaults can be configured on a per-field basis, see :ref:`per-field-plotconfig` + +**Constrained norms** + +The standard way to change colorbar scalings between linear, log, and symmetric +log (symlog). Colorbar properties can be constrained via two methods: + +- :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` controls the limits + of the colorbar range: ``zmin`` and ``zmax``. +- :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` allows switching to + linear or symlog normalization. With symlog, the linear threshold can be set + explicitly. Otherwise, yt will dynamically determine a reasonable value. + +Use the :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` +method to set a custom colormap range. .. python-script:: @@ -941,43 +972,109 @@ linear. ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_log(("gas", "density"), False) + slc.set_zlim(("gas", "density"), zmin=(1e-30, "g/cm**3"), zmax=(1e-25, "g/cm**3")) slc.save() -Specifically, a field containing both positive and negative values can be plotted -with symlog scale, by setting the boolean to be ``True`` and either providing an extra -parameter ``linthresh`` or setting ``symlog_auto = True``. In the region around zero -(when the log scale approaches to infinity), the linear scale will be applied to the -region ``(-linthresh, linthresh)`` and stretched relative to the logarithmic range. -In some cases, if yt detects zeros present in the dataset and the user has selected -``log`` scaling, yt automatically switches to ``symlog`` scaling and automatically -chooses a ``linthresh`` value to avoid errors. This is the same behavior you can -achieve by setting the keyword ``symlog_auto`` to ``True``. In these cases, yt will -choose the smallest non-zero value in a dataset to be the ``linthresh`` value. -As an example, +Units can be left out, in which case they implicitly match the current display +units of the colorbar (controlled with the ``set_unit`` method, see +:ref:`_set-image-units`). + +It is not required to specify both ``zmin`` and ``zmax``. Left unset, they will +default to the extreme values in the current view. This default behavior can be +enforced or restored by passing ``zmin="min"`` (reps. ``zmax="max"``) +explicitly. + + +:meth:`~yt.visualization.plot_container.ImagePlotContainer.set_log` takes a boolean argument +to select log (``True``) or linear (``False``) scalings. + +.. python-script:: + + import yt + + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") + slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) + slc.set_log(("gas", "density"), False) # switch to linear scaling + slc.save() + +One can switch to `symlog +`_ +by providing a "linear threshold" (``linthresh``) value. +With ``linthresh="auto"`` yt will switch to symlog norm and guess an appropriate value +automatically. Specifically the minimum absolute value in the image is used +unless it's zero, in which case yt uses 1/1000 of the maximum value. + +.. python-script:: + + import yt + + ds = yt.load_sample("IsolatedGalaxy") + slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) + slc.set_log(("gas", "density"), linthresh="auto") + slc.save() + + +In some cases, you might find that the automatically selected linear threshold is not +really suited to your dataset, for instance .. python-script:: import yt ds = yt.load_sample("FIRE_M12i_ref11") - p = yt.ProjectionPlot(ds, "x", ("gas", "density")) - p.set_log(("gas", "density"), True, symlog_auto=True) + p = yt.ProjectionPlot(ds, "x", ("gas", "density"), width=(30, "Mpc")) + p.set_log(("gas", "density"), linthresh="auto") p.save() -Symlog is very versatile, and will work with positive or negative dataset ranges. -Here is an example using symlog scaling to plot a positive field with a linear range of -``(0, linthresh)``. +An explicit value can be passed instead .. python-script:: import yt + ds = yt.load_sample("FIRE_M12i_ref11") + p = yt.ProjectionPlot(ds, "x", ("gas", "density"), width=(30, "Mpc")) + p.set_log(("gas", "density"), linthresh=(1e-22, "g/cm**2")) + p.save() + +Similar to the ``zmin`` and ``zmax`` arguments of the ``set_zlim`` method, units +can be left out in ``linthresh``. + + +**Arbitrary norms** + +Alternatively, arbitrary `matplotlib norms +`_ can be +passed via the :meth:`~yt.visualization.plot_container.PlotContainer.set_norm` +method. In that case, any numeric value is treated as having implicit units, +matching the current display units. This alternative interface is more flexible, +but considered experimental as of yt 4.1. Don't forget that with great power +comes great responsibility. + + +.. python-script:: + + import yt + from matplotlib.colors import TwoSlopeNorm + ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") slc = yt.SlicePlot(ds, "z", ("gas", "velocity_x"), width=(30, "kpc")) - slc.set_log(("gas", "velocity_x"), True, linthresh=1.0e1) + slc.set_norm(("gas", "velocity_x"), TwoSlopeNorm(vcenter=0)) + + # using a diverging colormap to emphasize that vcenter corresponds to the + # middle value in the color range + slc.set_cmap(("gas", "velocity_x"), "RdBu") slc.save() +.. note:: When calling + :meth:`~yt.visualization.plot_container.PlotContainer.set_norm`, any constraints + previously set with + :meth:`~yt.visualization.plot_container.PlotContainer.set_log` or + :meth:`~yt.visualization.plot_container.PlotContainer.set_zlim` will be dropped. + Conversely, calling ``set_log`` or ``set_zlim`` will have the + effect of dropping any norm previously set via ``set_norm``. + + The :meth:`~yt.visualization.plot_container.ImagePlotContainer.set_background_color` function accepts a field name and a color (optional). If color is given, the function will set the plot's background color to that. If not, it will set it to the bottom @@ -994,33 +1091,6 @@ value of the color map. slc.set_background_color(("gas", "density"), color="black") slc.save("black_background") -If you would like to change the background for a plot and also hide the axes, -you will need to make use of the ``draw_frame`` keyword argument for the ``hide_axes`` function. If you do not use this keyword argument, the call to -``set_background_color`` will have no effect. Here is an example illustrating how to use the ``draw_frame`` keyword argument for ``hide_axes``: - -.. python-script:: - - import yt - - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - field = ("deposit", "all_density") - slc = yt.ProjectionPlot(ds, "z", field, width=(1.5, "Mpc")) - slc.set_background_color(field) - slc.hide_axes(draw_frame=True) - slc.hide_colorbar() - slc.save("just_image") - -Lastly, the :meth:`~yt.visualization.plot_window.AxisAlignedSlicePlot.set_zlim` -function makes it possible to set a custom colormap range. - -.. python-script:: - - import yt - - ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030") - slc = yt.SlicePlot(ds, "z", ("gas", "density"), width=(10, "kpc")) - slc.set_zlim(("gas", "density"), 1e-30, 1e-25) - slc.save() Annotations ~~~~~~~~~~~ diff --git a/nose_unit.cfg b/nose_unit.cfg index 0b29a9dc2a..61b46a26ca 100644 --- a/nose_unit.cfg +++ b/nose_unit.cfg @@ -6,5 +6,5 @@ nologcapture=1 verbosity=2 where=yt with-timer=1 -ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_add_field\.py) +ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_set_zlim\.py|test_add_field\.py) exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF diff --git a/setup.cfg b/setup.cfg index 169a2d43f7..0314f43d3a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ project_urls = packages = find: install_requires = cmyt>=0.2.2 - matplotlib!=3.4.2,>=2.2.3 # keep in sync with tests/windows_conda_requirements.txt + matplotlib!=3.4.2,>=3.1 # keep in sync with tests/windows_conda_requirements.txt more-itertools>=8.4 numpy>=1.14.5 packaging>=20.9 @@ -50,6 +50,7 @@ install_requires = unyt>=2.8.0 importlib-metadata>=1.4;python_version < '3.8' tomli>=1.2.3;python_version < '3.11' + typing-extensions>=4.2.0;python_version < '3.8' python_requires = >=3.7 include_package_data = True scripts = scripts/iyt @@ -103,7 +104,7 @@ mapserver = bottle minimal = cmyt==0.2.2 - matplotlib==2.2.3 + matplotlib==3.1 more-itertools==8.4 numpy==1.14.5 pillow==6.2.0 diff --git a/tests/report_failed_answers.py b/tests/report_failed_answers.py index 8cbfa6c7f5..78a7c55d68 100644 --- a/tests/report_failed_answers.py +++ b/tests/report_failed_answers.py @@ -13,6 +13,7 @@ import os import re import shutil +import sys import tempfile import xml.etree.ElementTree as ET @@ -425,6 +426,9 @@ def handle_error(error, testcase, missing_errors, missing_answers, failed_answer + "\n" ) response = upload_answers(failed_answers) + if response is None: + log.error("Failed to upload answers for failed tests !") + sys.exit(1) if response.ok: msg += ( FLAG_EMOJI @@ -438,6 +442,9 @@ def handle_error(error, testcase, missing_errors, missing_answers, failed_answer if args.upload_missing_answers and missing_answers: response = upload_answers(missing_answers) + if response is None: + log.error("Failed to upload missing answers !") + sys.exit(1) if response.ok: msg = ( FLAG_EMOJI diff --git a/tests/tests.yaml b/tests/tests.yaml index b3a3060356..a34e00e070 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -95,7 +95,7 @@ answer_tests: - yt/frontends/owls/tests/test_outputs.py:test_snapshot_033 - yt/frontends/owls/tests/test_outputs.py:test_OWLS_particlefilter - local_pw_044: # PR 3640 + local_pw_046: # PR 3849 - yt/visualization/tests/test_plotwindow.py:test_attributes - yt/visualization/tests/test_particle_plot.py:test_particle_projection_answers - yt/visualization/tests/test_particle_plot.py:test_particle_projection_filter @@ -183,6 +183,16 @@ answer_tests: local_nc4_cm1_002: # PR 2176, 2998 - yt/frontends/nc4_cm1/tests/test_outputs.py:test_cm1_mesh_fields + local_norm_api_008: # PR 3849 + - yt/visualization/tests/test_norm_api_lineplot.py:test_lineplot_set_axis_properties + - yt/visualization/tests/test_norm_api_profileplot.py:test_profileplot_set_axis_properties + - yt/visualization/tests/test_norm_api_custom_norm.py:test_sliceplot_custom_norm + - yt/visualization/tests/test_norm_api_set_background_color.py:test_sliceplot_set_background_color + - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py:test_phaseplot_set_colorbar_properties_implicit + - yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py:test_phaseplot_set_colorbar_properties_explicit + - yt/visualization/tests/test_norm_api_particleplot.py:test_particleprojectionplot_set_colorbar_properties + - yt/visualization/tests/test_norm_api_inf_zlim.py:test_inf_and_finite_values_zlim + local_cf_radial_002: # PR 1990 - yt/frontends/cf_radial/tests/test_outputs.py:test_cfradial_grid_field_values @@ -206,6 +216,7 @@ other_tests: - "--ignore-files=test_normal_plot_api\\.py" - "--ignore-file=test_file_sanitizer\\.py" - "--ignore-files=test_version\\.py" + - "--ignore-files=test_set_zlim\\.py" - "--ignore-file=test_add_field\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" diff --git a/tests/windows_conda_requirements.txt b/tests/windows_conda_requirements.txt index fbec89d138..60dfa90957 100644 --- a/tests/windows_conda_requirements.txt +++ b/tests/windows_conda_requirements.txt @@ -2,5 +2,5 @@ numpy>=1.19.4 cython>=0.29.21,<3.0 cartopy>=0.20.1 h5py~=3.1.0 -matplotlib!=3.4.2,>=2.2.3 # keep in sync with setup.cfg +matplotlib!=3.4.2,>=3.1 # keep in sync with setup.cfg scipy~=1.5.0 diff --git a/yt/_maintenance/backports.py b/yt/_maintenance/backports.py index 7e84b15419..3f4a2ac5f5 100644 --- a/yt/_maintenance/backports.py +++ b/yt/_maintenance/backports.py @@ -80,3 +80,18 @@ def __get__(self, instance, owner=None): else: pass + + +builtin_zip = zip +if sys.version_info >= (3, 10): + zip = builtin_zip +else: + # this function is deprecated in more_itertools + # because it is superseded by the standard library + from more_itertools import zip_equal + + def zip(*args, strict=False): + if strict: + return zip_equal(*args) + else: + return builtin_zip(*args) diff --git a/yt/_typing.py b/yt/_typing.py index 4b76c9899f..fd248f343d 100644 --- a/yt/_typing.py +++ b/yt/_typing.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple, Union +import unyt as un from numpy import ndarray FieldDescT = Tuple[str, Tuple[str, List[str], Optional[str]]] @@ -10,3 +11,10 @@ Tuple[ndarray, ndarray, ndarray], # xyz Union[float, ndarray], # hsml ] + + +# types that can be converted to un.Unit +Unit = Union[un.Unit, str] + +# types that can be converted to un.unyt_quantity +Quantity = Union[un.unyt_quantity, Tuple[float, Unit]] diff --git a/yt/funcs.py b/yt/funcs.py index 0f97e96339..3fabae7681 100644 --- a/yt/funcs.py +++ b/yt/funcs.py @@ -1260,6 +1260,12 @@ def dictWithFactory(factory: Callable[[Any], Any]) -> Type: A class to create new dictionaries handling missing keys. """ + issue_deprecation_warning( + "yt.funcs.dictWithFactory will be removed in a future version of yt, please do not rely on it. " + "If you need it, copy paste this function from yt's source code", + since="4.1", + ) + class DictWithFactory(dict): def __init__(self, *args, **kwargs): self.factory = factory diff --git a/yt/utilities/answer_testing/framework.py b/yt/utilities/answer_testing/framework.py index c66c5a74a8..721846cc44 100644 --- a/yt/utilities/answer_testing/framework.py +++ b/yt/utilities/answer_testing/framework.py @@ -278,6 +278,7 @@ def get(self, ds_name, default=None): return default # Read data using shelve answer_name = f"{ds_name}" + os.makedirs(os.path.dirname(self.reference_name), exist_ok=True) ds = shelve.open(self.reference_name, protocol=-1) try: result = ds[answer_name] diff --git a/yt/utilities/exceptions.py b/yt/utilities/exceptions.py index 04d16d22e2..6533f320cc 100644 --- a/yt/utilities/exceptions.py +++ b/yt/utilities/exceptions.py @@ -902,6 +902,10 @@ def __str__(self): return msg +class YTConfigurationError(YTException): + pass + + class GenerationInProgress(Exception): def __init__(self, fields): self.fields = fields diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index 88c7990698..6e08d2750a 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -9,8 +9,17 @@ else: from importlib_metadata import version +import numpy as np +from more_itertools import always_iterable from packaging.version import Version +from yt.config import ytcfg + +if sys.version_info >= (3, 10): + pass +else: + from yt._maintenance.backports import zip + if TYPE_CHECKING: from ._mpl_imports import FigureCanvasBase @@ -215,3 +224,141 @@ def _swap_arg_pair_order(*args): new_args.append(args[x_id + 1]) new_args.append(args[x_id]) return tuple(new_args) + + +def get_log_minorticks(vmin: float, vmax: float) -> np.ndarray: + """calculate positions of linear minorticks on a log colorbar + + Parameters + ---------- + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ + expA = np.floor(np.log10(vmin)) + expB = np.floor(np.log10(vmax)) + cofA = np.ceil(vmin / 10**expA).astype("int64") + cofB = np.floor(vmax / 10**expB).astype("int64") + lmticks = np.empty(0) + while cofA * 10**expA <= cofB * 10**expB: + if expA < expB: + lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA)) + cofA = 1 + expA += 1 + else: + lmticks = np.hstack( + (lmticks, np.linspace(cofA, cofB, cofB - cofA + 1) * 10**expA) + ) + expA += 1 + return np.array(lmticks) + + +def get_symlog_minorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray: + """calculate positions of linear minorticks on a symmetric log colorbar + + Parameters + ---------- + linthresh : float + the threshold for the linear region + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ + if vmin > 0: + return get_log_minorticks(vmin, vmax) + elif vmax < 0 and vmin < 0: + return -get_log_minorticks(-vmax, -vmin) + elif vmin == 0: + return np.hstack((0, get_log_minorticks(linthresh, vmax))) + elif vmax == 0: + return np.hstack((-get_log_minorticks(linthresh, -vmin)[::-1], 0)) + else: + return np.hstack( + ( + -get_log_minorticks(linthresh, -vmin)[::-1], + 0, + get_log_minorticks(linthresh, vmax), + ) + ) + + +def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray: + """calculate positions of major ticks on a log colorbar + + Parameters + ---------- + linthresh : float + the threshold for the linear region + vmin : float + the minimum value in the colorbar + vmax : float + the maximum value in the colorbar + + """ + if vmin >= 0.0: + yticks = [vmin] + list( + 10 + ** np.arange( + np.rint(np.log10(linthresh)), + np.ceil(np.log10(1.1 * vmax)), + ) + ) + elif vmax <= 0.0: + if MPL_VERSION >= Version("3.5.0b"): + offset = 0 + else: + offset = 1 + + yticks = list( + -( + 10 + ** np.arange( + np.floor(np.log10(-vmin)), + np.rint(np.log10(linthresh)) - offset, + -1, + ) + ) + ) + [vmax] + else: + yticks = ( + list( + -( + 10 + ** np.arange( + np.floor(np.log10(-vmin)), + np.rint(np.log10(linthresh)) - 1, + -1, + ) + ) + ) + + [0] + + list( + 10 + ** np.arange( + np.rint(np.log10(linthresh)), + np.ceil(np.log10(1.1 * vmax)), + ) + ) + ) + if yticks[-1] > vmax: + yticks.pop() + return np.array(yticks) + + +def get_default_from_config(data_source, *, field, keys, defaults): + _keys = list(always_iterable(keys)) + _defaults = list(always_iterable(defaults)) + + ftype, fname = data_source._determine_fields(field)[0] + ret = [ + ytcfg.get_most_specific("plot", ftype, fname, key, fallback=default) + for key, default in zip(_keys, _defaults, strict=True) + ] + if len(ret) == 1: + return ret[0] + else: + return ret diff --git a/yt/visualization/_handlers.py b/yt/visualization/_handlers.py new file mode 100644 index 0000000000..8ec3d1ec77 --- /dev/null +++ b/yt/visualization/_handlers.py @@ -0,0 +1,454 @@ +import sys +import weakref +from numbers import Real +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import unyt as un +from matplotlib.cm import get_cmap +from matplotlib.colors import Colormap, LogNorm, Normalize, SymLogNorm +from packaging.version import Version + +from yt._typing import Quantity, Unit +from yt.config import ytcfg +from yt.funcs import get_brewer_cmap, is_sequence, mylog +from yt.visualization._commons import MPL_VERSION + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + +class NormHandler: + """ + A bookkeeper class that can hold a fully defined norm object, or dynamically + build one on demand according to a set of constraints. + + If a fully defined norm object is added, any existing constraints are + dropped, and vice versa. These rules are implemented with properties and + watcher patterns. + + It also keeps track of display units so that vmin, vmax and linthresh can be + updated with implicit units. + """ + + # using slots here to minimize the risk of introducing bugs + # since attributes names are essential to this class's implementation + __slots__ = ( + "data_source", + "ds", + "_display_units", + "_vmin", + "_vmax", + "_dynamic_range", + "_norm_type", + "_linthresh", + "_norm_type", + "_norm", + ) + _constraint_attrs: List[str] = [ + "vmin", + "vmax", + "dynamic_range", + "norm_type", + "linthresh", + ] + + def __init__( + self, + data_source, + *, + display_units: un.Unit, + vmin: Optional[un.unyt_quantity] = None, + vmax: Optional[un.unyt_quantity] = None, + dynamic_range: Optional[float] = None, + norm_type: Optional[Type[Normalize]] = None, + norm: Optional[Normalize] = None, + linthresh: Optional[float] = None, + ): + self.data_source = weakref.proxy(data_source) + self.ds = data_source.ds # should already be a weakref proxy + self._display_units = display_units + + self._norm = norm + self._vmin = vmin + self._vmax = vmax + self._dynamic_range = dynamic_range + self._norm_type = norm_type + self._linthresh = linthresh + + if self.has_norm and self.has_constraints: + raise TypeError( + "NormHandler input is malformed. " + "A norm cannot be passed along other constraints." + ) + + def _get_constraints(self) -> Dict[str, Any]: + return { + attr: getattr(self, attr) + for attr in self.__class__._constraint_attrs + if getattr(self, attr) is not None + } + + @property + def has_constraints(self) -> bool: + return bool(self._get_constraints()) + + def _reset_constraints(self) -> None: + constraints = self._get_constraints() + if not constraints: + return + + msg = ", ".join([f"{name}={value}" for name, value in constraints.items()]) + mylog.warning("Dropping norm constraints (%s)", msg) + for name in constraints.keys(): + setattr(self, name, None) + + @property + def has_norm(self) -> bool: + return self._norm is not None + + def _reset_norm(self): + if not self.has_norm: + return + mylog.warning("Dropping norm (%s)", self.norm) + self._norm = None + + def to_float(self, val: un.unyt_quantity) -> float: + return float(val.to(self.display_units).d) + + def to_quan(self, val) -> un.unyt_quantity: + if isinstance(val, un.unyt_quantity): + return self.ds.quan(val) + elif ( + is_sequence(val) + and len(val) == 2 + and isinstance(val[0], Real) + and isinstance(val[1], (str, un.Unit)) + ): + return self.ds.quan(*val) + elif isinstance(val, Real): + return self.ds.quan(val, self.display_units) + else: + raise TypeError(f"Could not convert {val!r} to unyt_quantity") + + @property + def display_units(self) -> un.Unit: + return self._display_units + + @display_units.setter + def display_units(self, newval: Unit) -> None: + self._display_units = un.Unit(newval, registry=self.ds.unit_registry) + + def _set_quan_attr( + self, attr: str, newval: Optional[Union[Quantity, float]] + ) -> None: + if newval is None: + setattr(self, attr, None) + else: + try: + quan = self.to_quan(newval) + except TypeError as exc: + raise TypeError( + "Expected None, a float, or a unyt_quantity, " + f"received {newval} with type {type(newval)}" + ) from exc + else: + setattr(self, attr, quan) + + @property + def vmin(self) -> Optional[Union[un.unyt_quantity, Literal["min"]]]: + return self._vmin + + @vmin.setter + def vmin(self, newval: Optional[Union[Quantity, float, Literal["min"]]]) -> None: + self._reset_norm() + if newval == "min": + self._vmin = "min" + else: + self._set_quan_attr("_vmin", newval) + + @property + def vmax(self) -> Optional[Union[un.unyt_quantity, Literal["max"]]]: + return self._vmax + + @vmax.setter + def vmax(self, newval: Optional[Union[Quantity, float, Literal["max"]]]) -> None: + self._reset_norm() + if newval == "max": + self._vmax = "max" + else: + self._set_quan_attr("_vmax", newval) + + @property + def dynamic_range(self) -> Optional[float]: + return self._dynamic_range + + @dynamic_range.setter + def dynamic_range(self, newval: Optional[float]) -> None: + if newval is None: + return + + try: + newval = float(newval) + except TypeError: + raise TypeError( + f"Expected a float. Received {newval} with type {type(newval)}" + ) from None + + if newval == 0: + raise ValueError("Dynamic range cannot be zero.") + + if newval == 1: + raise ValueError("Dynamic range cannot be unity.") + + self._reset_norm() + self._dynamic_range = newval + + def get_dynamic_range( + self, dvmin: Optional[float], dvmax: Optional[float] + ) -> Tuple[float, float]: + if self.dynamic_range is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + + if self.vmax is None: + if self.vmin is None: + raise TypeError( + "Cannot set dynamic range with neither " + "vmin and vmax being constrained." + ) + if dvmin is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + return dvmin, dvmin * self.dynamic_range + elif self.vmin is None: + if dvmax is None: + raise RuntimeError( + "Something went terribly wrong in setting up a dynamic range" + ) + return dvmax / self.dynamic_range, dvmax + else: + raise TypeError( + "Cannot set dynamic range with both " + "vmin and vmax already constrained." + ) + + @property + def norm_type(self) -> Optional[Type[Normalize]]: + return self._norm_type + + @norm_type.setter + def norm_type(self, newval: Optional[Type[Normalize]]) -> None: + if not ( + newval is None + or (isinstance(newval, type) and issubclass(newval, Normalize)) + ): + raise TypeError( + "Expected a subclass of matplotlib.colors.Normalize, " + f"received {newval} with type {type(newval)}" + ) + self._reset_norm() + if newval is not SymLogNorm: + self.linthresh = None + self._norm_type = newval + + @property + def norm(self) -> Optional[Normalize]: + return self._norm + + @norm.setter + def norm(self, newval: Normalize) -> None: + if not isinstance(newval, Normalize): + raise TypeError( + "Expected a matplotlib.colors.Normalize object, " + f"received {newval} with type {type(newval)}" + ) + self._reset_constraints() + self._norm = newval + + @property + def linthresh(self) -> Optional[float]: + return self._linthresh + + @linthresh.setter + def linthresh(self, newval: Optional[Union[Quantity, float]]) -> None: + self._reset_norm() + self._set_quan_attr("_linthresh", newval) + if self._linthresh is not None and self._linthresh <= 0: + raise ValueError( + f"linthresh can only be set to strictly positive values, got {newval}" + ) + if newval is not None: + self.norm_type = SymLogNorm + + def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize: + if self.has_norm: + return self.norm + + dvmin = dvmax = None + + finite_values_mask = np.isfinite(data) + # FUTURE: when the minimal supported version of numpy reaches 1.16 or newer, + # this complicated conditional can be simplified into + # if self.vmin not in (None, "min"): + if self.vmin is not None and not ( + isinstance(self.vmin, str) and self.vmin == "min" + ): + dvmin = self.to_float(self.vmin) + elif np.any(finite_values_mask): + dvmin = self.to_float(np.nanmin(data[finite_values_mask])) + + # FUTURE: see above + if self.vmax is not None and not ( + isinstance(self.vmax, str) and self.vmax == "max" + ): + dvmax = self.to_float(self.vmax) + elif np.any(finite_values_mask): + dvmax = self.to_float(np.nanmax(data[finite_values_mask])) + + if self.dynamic_range is not None: + dvmin, dvmax = self.get_dynamic_range(dvmin, dvmax) + + if dvmin is None: + dvmin = 1 * getattr(data, "units", 1) + kw.setdefault("vmin", dvmin) + + if dvmax is None: + dvmax = 1 * getattr(data, "units", 1) + kw.setdefault("vmax", dvmax) + + if self.norm_type is not None: + # this is a convenience mechanism for backward compat, + # allowing to toggle between lin and log scaling without detailed user input + norm_type = self.norm_type + else: + if kw["vmin"] == kw["vmax"] or not np.any(np.isfinite(data)): + norm_type = Normalize + elif kw["vmin"] <= 0: + norm_type = SymLogNorm + elif ( + Version("3.3") <= MPL_VERSION < Version("3.5") + and kw["vmin"] == 0 + and kw["vmax"] > 0 + ): + # normally, a LogNorm scaling would still be OK here because + # LogNorm will mask 0 values when calculating vmin. But + # due to a bug in matplotlib's imshow, if the data range + # spans many orders of magnitude while containing zero points + # vmin can get rescaled to 0, resulting in an error when the image + # gets drawn. So here we switch to symlog to avoid that until + # a fix is in -- see PR #3161 and linked issue. + cutoff_sigdigs = 15 + if ( + np.log10(np.nanmax(data[np.isfinite(data)])) + - np.log10(np.nanmin(data[data > 0])) + > cutoff_sigdigs + ): + norm_type = SymLogNorm + else: + norm_type = LogNorm + else: + norm_type = LogNorm + + if norm_type is SymLogNorm: + # if cblinthresh is not specified, try to come up with a reasonable default + min_abs_val, max_abs_val = np.sort( + np.abs((self.to_float(np.nanmin(data)), self.to_float(np.nanmax(data)))) + ) + if self.linthresh is not None: + linthresh = self.to_float(self.linthresh) + elif min_abs_val > 0: + linthresh = min_abs_val + else: + linthresh = max_abs_val / 1000 + kw.setdefault("linthresh", linthresh) + if MPL_VERSION >= Version("3.2"): + # note that this creates an inconsistency between mpl versions + # since the default value previous to mpl 3.4.0 is np.e + # but it is only exposed since 3.2.0 + kw.setdefault("base", 10) + + return norm_type(*args, **kw) + + +class ColorbarHandler: + __slots__ = ("_draw_cbar", "_draw_minorticks", "_cmap", "_background_color") + + def __init__( + self, + *, + draw_cbar: bool = True, + draw_minorticks: bool = True, + cmap: Optional[Union[Colormap, str]] = None, + background_color: Optional[str] = None, + ): + self._draw_cbar = draw_cbar + self._draw_minorticks = draw_minorticks + self._cmap: Optional[Colormap] = None + self.cmap = cmap + self._background_color = background_color + + @property + def draw_cbar(self) -> bool: + return self._draw_cbar + + @draw_cbar.setter + def draw_cbar(self, newval) -> None: + if not isinstance(newval, bool): + raise TypeError( + f"Excpected a boolean, got {newval} with type {type(newval)}" + ) + self._draw_cbar = newval + + @property + def draw_minorticks(self) -> bool: + return self._draw_minorticks + + @draw_minorticks.setter + def draw_minorticks(self, newval) -> None: + if not isinstance(newval, bool): + raise TypeError( + f"Excpected a boolean, got {newval} with type {type(newval)}" + ) + self._draw_minorticks = newval + + @property + def cmap(self) -> Colormap: + return self._cmap or get_cmap(ytcfg.get("yt", "default_colormap")) + + @cmap.setter + def cmap(self, newval) -> None: + if isinstance(newval, Colormap) or newval is None: + self._cmap = newval + elif isinstance(newval, str): + self._cmap = get_cmap(newval) + elif is_sequence(newval): + # tuple colormaps are from palettable (or brewer2mpl) + self._cmap = get_brewer_cmap(newval) + else: + raise TypeError( + "Expected a colormap object or name, " + f"got {newval} with type {type(newval)}" + ) + + @property + def background_color(self) -> Any: + return self._background_color or "white" + + @background_color.setter + def background_color(self, newval: Any): + # not attempting to constrain types here because + # down the line it really depends on matplotlib.axes.Axes.set_faceolor + # which is very type-flexibile + if newval is None: + self._background_color = self.cmap(0) + else: + self._background_color = newval + + @property + def has_background_color(self) -> bool: + return self._background_color is not None diff --git a/yt/visualization/base_plot_types.py b/yt/visualization/base_plot_types.py index f0f45bc26f..027f10b5ab 100644 --- a/yt/visualization/base_plot_types.py +++ b/yt/visualization/base_plot_types.py @@ -1,20 +1,27 @@ import sys import warnings +from abc import ABC from io import BytesIO +from typing import Optional, Tuple, Union import matplotlib import numpy as np +from matplotlib.axis import Axis +from matplotlib.colors import LogNorm, Normalize, SymLogNorm +from matplotlib.figure import Figure +from matplotlib.ticker import LogFormatterMathtext from packaging.version import Version -from yt.funcs import ( - get_brewer_cmap, - get_interactivity, - is_sequence, - matplotlib_style_context, - mylog, -) +from yt.funcs import get_interactivity, is_sequence, matplotlib_style_context, mylog +from yt.visualization._handlers import ColorbarHandler, NormHandler -from ._commons import MPL_VERSION, get_canvas, validate_image_name +from ._commons import ( + MPL_VERSION, + get_canvas, + get_symlog_majorticks, + get_symlog_minorticks, + validate_image_name, +) BACKEND_SPECS = { "GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"], @@ -78,7 +85,15 @@ def __init__(self, viewer, window_plot, frb, field, font_properties, font_color) class PlotMPL: """A base class for all yt plots made using matplotlib, that is backend independent.""" - def __init__(self, fsize, axrect, figure, axes): + def __init__( + self, + fsize, + axrect, + *, + norm_handler: NormHandler, + figure: Optional[Figure] = None, + axes: Optional[Axis] = None, + ): """Initialize PlotMPL class""" import matplotlib.figure @@ -107,6 +122,8 @@ def __init__(self, fsize, axrect, figure, axes): which="both", axis="both", direction="in", top=True, right=True ) + self.norm_handler = norm_handler + def _create_axes(self, axrect): self.axes = self.figure.add_axes(axrect) @@ -142,9 +159,7 @@ def save(self, name, mpl_kwargs=None, canvas=None): if mpl_kwargs is None: mpl_kwargs = {} - if "papertype" not in mpl_kwargs and Version(matplotlib.__version__) < Version( - "3.3.0" - ): + if "papertype" not in mpl_kwargs and MPL_VERSION < Version("3.3.0"): mpl_kwargs["papertype"] = "auto" name = validate_image_name(name) @@ -196,13 +211,38 @@ def _repr_png_(self): return f.read() -class ImagePlotMPL(PlotMPL): +class ImagePlotMPL(PlotMPL, ABC): """A base class for yt plots made using imshow""" - def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax): + _default_font_size = 18.0 + + def __init__( + self, + fsize=None, + axrect=None, + caxrect=None, + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, + figure: Optional[Figure] = None, + axes: Optional[Axis] = None, + cax: Optional[Axis] = None, + ): """Initialize ImagePlotMPL class object""" - super().__init__(fsize, axrect, figure, axes) - self.zmin, self.zmax = zlim + self.colorbar_handler = colorbar_handler + _missing_layout_specs = [_ is None for _ in (fsize, axrect, caxrect)] + + if all(_missing_layout_specs): + fsize, axrect, caxrect = self._get_best_layout() + elif any(_missing_layout_specs): + raise TypeError( + "ImagePlotMPL cannot be initialized with partially specified layout." + ) + + super().__init__( + fsize, axrect, norm_handler=norm_handler, figure=figure, axes=axes + ) + if cax is None: self.cax = self.figure.add_axes(caxrect) else: @@ -210,12 +250,39 @@ def __init__(self, fsize, axrect, caxrect, zlim, figure, axes, cax): cax.set_position(caxrect) self.cax = cax - def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): + def _setup_layout_constraints( + self, figure_size: Union[Tuple[float, float], float], fontsize: float + ): + # Setup base layout attributes + # derived classes need to call this before super().__init__ + # but they are free to do other stuff in between + + if isinstance(figure_size, tuple): + assert len(figure_size) == 2 + assert all(isinstance(_, float) for _ in figure_size) + self._figure_size = figure_size + else: + assert isinstance(figure_size, float) + self._figure_size = (figure_size, figure_size) + + self._draw_axes = True + fontscale = float(fontsize) / self.__class__._default_font_size + if fontscale < 1.0: + fontscale = np.sqrt(fontscale) + + self._cb_size = 0.0375 * self._figure_size[0] + self._ax_text_size = [1.2 * fontscale, 0.9 * fontscale] + self._top_buff_size = 0.30 * fontscale + self._aspect = 1.0 + + def _reset_layout(self) -> None: + size, axrect, caxrect = self._get_best_layout() + self.axes.set_position(axrect) + self.cax.set_position(caxrect) + self.figure.set_size_inches(*size) + + def _init_image(self, data, extent, aspect): """Store output of imshow in image variable""" - cbnorm_kwargs = dict( - vmin=float(self.zmin) if self.zmin is not None else None, - vmax=float(self.zmax) if self.zmax is not None else None, - ) if MPL_VERSION < Version("3.2"): # with MPL 3.1 we use np.inf as a mask instead of np.nan @@ -225,50 +292,8 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): # see https://github.com/yt-project/yt/pull/2517 and https://github.com/yt-project/yt/pull/3793 data[~np.isfinite(data)] = np.nan - zmin = float(self.zmin) if self.zmin is not None else np.nanmin(data) - zmax = float(self.zmax) if self.zmax is not None else np.nanmax(data) - - if cbnorm == "symlog": - # if cblinthresh is not specified, try to come up with a reasonable default - min_abs_val, max_abs_val = np.sort( - np.abs((np.nanmin(data), np.nanmax(data))) - ) - if cblinthresh is not None: - if zmin * zmax > 0 and cblinthresh < min_abs_val: - # see https://github.com/yt-project/yt/issues/3564 - warnings.warn( - f"Cannot set a symlog norm with linear threshold {cblinthresh} " - f"lower than the minimal absolute data value {min_abs_val} . " - "Switching to log norm." - ) - cbnorm = "log10" - elif min_abs_val > 0: - cblinthresh = min_abs_val - else: - cblinthresh = max_abs_val / 1000 - - if cbnorm == "log10": - cbnorm_cls = matplotlib.colors.LogNorm - elif cbnorm == "linear": - cbnorm_cls = matplotlib.colors.Normalize - elif cbnorm == "symlog": - cbnorm_kwargs.update(dict(linthresh=cblinthresh)) - if MPL_VERSION >= Version("3.2.0"): - # note that this creates an inconsistency between mpl versions - # since the default value previous to mpl 3.4.0 is np.e - # but it is only exposed since 3.2.0 - cbnorm_kwargs["base"] = 10 - - cbnorm_cls = matplotlib.colors.SymLogNorm - else: - raise ValueError(f"Unknown value `cbnorm` == {cbnorm}") - - norm = cbnorm_cls(**cbnorm_kwargs) - + norm = self.norm_handler.get_norm(data) extent = [float(e) for e in extent] - # tuple colormaps are from palettable (or brewer2mpl) - if isinstance(cmap, tuple): - cmap = get_brewer_cmap(cmap) if self._transform is None: # sets the transform to be an ax.TransData object, where the @@ -286,66 +311,57 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): extent=extent, norm=norm, aspect=aspect, - cmap=cmap, + cmap=self.colorbar_handler.cmap, interpolation="nearest", transform=transform, ) - if cbnorm == "symlog": - formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh) - self.cb = self.figure.colorbar(self.image, self.cax, format=formatter) + self._set_axes(norm) - if zmin >= 0.0: - yticks = [zmin] + list( - 10 - ** np.arange( - np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(1.1 * zmax)), - ) - ) - elif zmax <= 0.0: - if MPL_VERSION >= Version("3.5.0b"): - offset = 0 - else: - offset = 1 - - yticks = list( - -( - 10 - ** np.arange( - np.floor(np.log10(-zmin)), - np.rint(np.log10(cblinthresh)) - offset, - -1, - ) - ) - ) + [zmax] - else: - yticks = ( - list( - -( - 10 - ** np.arange( - np.floor(np.log10(-zmin)), - np.rint(np.log10(cblinthresh)) - 1, - -1, - ) - ) - ) - + [0] - + list( - 10 - ** np.arange( - np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(1.1 * zmax)), - ) - ) + def _set_axes(self, norm: Normalize) -> None: + if isinstance(norm, SymLogNorm): + formatter = LogFormatterMathtext(linthresh=norm.linthresh) + self.cb = self.figure.colorbar(self.image, self.cax, format=formatter) + self.cb.set_ticks( + get_symlog_majorticks( + linthresh=norm.linthresh, vmin=norm.vmin, vmax=norm.vmax ) - if yticks[-1] > zmax: - yticks.pop() - self.cb.set_ticks(yticks) + ) else: self.cb = self.figure.colorbar(self.image, self.cax) self.cax.tick_params(which="both", axis="y", direction="in") + fmt_kwargs = dict(style="scientific", scilimits=(-2, 3), useMathText=True) + self.image.axes.ticklabel_format(**fmt_kwargs) + if type(norm) not in (LogNorm, SymLogNorm): + self.cb.ax.ticklabel_format(**fmt_kwargs) + if self.colorbar_handler.draw_minorticks: + if isinstance(norm, SymLogNorm): + if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"): + # no known working method to draw symlog minor ticks + # see https://github.com/yt-project/yt/issues/3535 + pass + else: + flinthresh = 10 ** np.floor(np.log10(norm.linthresh)) + absmax = np.abs((norm.vmin, norm.vmax)).max() + if (absmax - flinthresh) / absmax < 0.1: + flinthresh /= 10 + mticks = get_symlog_minorticks(flinthresh, norm.vmin, norm.vmax) + if MPL_VERSION < Version("3.5.0b"): + # https://github.com/matplotlib/matplotlib/issues/21258 + mticks = self.image.norm(mticks) + self.cax.yaxis.set_ticks(mticks, minor=True) + + elif isinstance(norm, LogNorm): + self.cax.minorticks_on() + self.cax.xaxis.set_visible(False) + + else: + self.cax.minorticks_on() + else: + self.cax.minorticks_off() + + self.image.axes.set_facecolor(self.colorbar_handler.background_color) + def _validate_axes_extent(self, extent, transform): # if the axes are cartopy GeoAxes, this checks that the axes extent # is properly set. @@ -374,6 +390,16 @@ def _validate_axes_extent(self, extent, transform): self.axes.set_extent(extent, crs=transform) def _get_best_layout(self): + # this method is called in ImagePlotMPL.__init__ + # required attributes + # - self._figure_size: Union[float, Tuple[float, float]] + # - self._aspect: float + # - self._ax_text_size: Tuple[float, float] + # - self._draw_axes: bool + # - self.colorbar_handler: ColorbarHandler + + # optional attribtues + # - self._unit_aspect: float # Ensure the figure size along the long axis is always equal to _figure_size unit_aspect = getattr(self, "_unit_aspect", 1) @@ -388,7 +414,7 @@ def _get_best_layout(self): else: y_fig_size /= scaling - if self._draw_colorbar: + if self.colorbar_handler.draw_cbar: cb_size = self._cb_size cb_text_size = self._ax_text_size[1] + 0.45 else: @@ -404,7 +430,7 @@ def _get_best_layout(self): top_buff_size = self._top_buff_size - if not self._draw_axes and not self._draw_colorbar: + if not self._draw_axes and not self.colorbar_handler.draw_cbar: x_axis_size = 0.0 y_axis_size = 0.0 cb_size = 0.0 @@ -454,25 +480,30 @@ def _toggle_axes(self, choice, draw_frame=None): If True, set the axes to be drawn. If False, set the axes to not be drawn. """ - if draw_frame is None: - draw_frame = choice self._draw_axes = choice self._draw_frame = draw_frame + if draw_frame is None: + draw_frame = choice + if self.colorbar_handler.has_background_color and not draw_frame: + # workaround matplotlib's behaviour + # last checked with Matplotlib 3.5 + warnings.warn( + f"Previously set background color {self.colorbar_handler.background_color} " + "has no effect. Pass `draw_axis=True` if you wish to preserve background color.", + stacklevel=4, + ) self.axes.set_frame_on(draw_frame) self.axes.get_xaxis().set_visible(choice) self.axes.get_yaxis().set_visible(choice) - size, axrect, caxrect = self._get_best_layout() - self.axes.set_position(axrect) - self.cax.set_position(caxrect) - self.figure.set_size_inches(*size) + self._reset_layout() - def _toggle_colorbar(self, choice): + def _toggle_colorbar(self, choice: bool): """ Turn on/off displaying the colorbar for a plot choice = True or False """ - self._draw_colorbar = choice + self.colorbar_handler.draw_cbar = choice self.cax.set_visible(choice) size, axrect, caxrect = self._get_best_layout() self.axes.set_position(axrect) @@ -486,7 +517,7 @@ def _get_labels(self): labels += [cbax.yaxis.label, cbax.yaxis.get_offset_text()] return labels - def hide_axes(self, draw_frame=None): + def hide_axes(self, *, draw_frame=None): """ Hide the axes for a plot including ticks and labels """ diff --git a/yt/visualization/eps_writer.py b/yt/visualization/eps_writer.py index 13381b881c..eba18bc25b 100644 --- a/yt/visualization/eps_writer.py +++ b/yt/visualization/eps_writer.py @@ -861,7 +861,7 @@ def colorbar_yt(self, plot, field=None, cb_labels=None, **kwargs): if field is not None: self.field = plot.data_source._determine_fields(field)[0] if isinstance(plot, (PlotWindow, PhasePlot)): - _cmap = plot._colormap_config[self.field] + _cmap = plot[self.field].colorbar_handler.cmap else: if plot.cmap is not None: _cmap = plot.cmap.name diff --git a/yt/visualization/line_plot.py b/yt/visualization/line_plot.py index 5df15c5d8d..2d44176138 100644 --- a/yt/visualization/line_plot.py +++ b/yt/visualization/line_plot.py @@ -1,17 +1,16 @@ from collections import defaultdict +from typing import Optional import numpy as np +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from yt.funcs import is_sequence, mylog from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTArray -from yt.visualization.base_plot_types import PlotMPL from yt.visualization.plot_container import ( - PlotContainer, + BaseLinePlot, PlotDictionary, invalidate_plot, - linear_transform, - log_transform, ) @@ -87,10 +86,7 @@ def _sanitize_dimensions(self, item): ).dimensions if dimensions not in self.known_dimensions: self.known_dimensions[dimensions] = item - ret_item = item - else: - ret_item = self.known_dimensions[dimensions] - return ret_item + return self.known_dimensions[dimensions] def __getitem__(self, item): ret_item = self._sanitize_dimensions(item) @@ -105,7 +101,7 @@ def __contains__(self, item): return super().__contains__(ret_item) -class LinePlot(PlotContainer): +class LinePlot(BaseLinePlot): r""" A class for constructing line plots @@ -152,8 +148,12 @@ class LinePlot(PlotContainer): >>> plot.save() """ + _plot_dict_type = LinePlotDictionary _plot_type = "line_plot" + _default_figure_size = (5.0, 5.0) + _default_font_size = 14.0 + def __init__( self, ds, @@ -161,8 +161,8 @@ def __init__( start_point, end_point, npoints, - figure_size=5, - fontsize=14, + figure_size=None, + fontsize: Optional[float] = None, field_labels=None, ): """ @@ -175,25 +175,18 @@ def __init__( @classmethod def _initialize_instance( - cls, obj, ds, fields, figure_size=5, fontsize=14, field_labels=None + cls, obj, ds, fields, figure_size, fontsize, field_labels=None ): obj._x_unit = None - obj._y_units = {} obj._titles = {} data_source = ds.all_data() obj.fields = data_source._determine_fields(fields) - obj.plots = LinePlotDictionary(data_source) obj.include_legend = defaultdict(bool) - super(LinePlot, obj).__init__(data_source, figure_size, fontsize) - for f in obj.fields: - finfo = obj.data_source.ds._get_field_info(*f) - if finfo.take_log: - obj._field_transform[f] = log_transform - else: - obj._field_transform[f] = linear_transform - + super(LinePlot, obj).__init__( + data_source, figure_size=figure_size, fontsize=fontsize + ) if field_labels is None: obj.field_labels = {} else: @@ -202,9 +195,35 @@ def _initialize_instance( if f not in obj.field_labels: obj.field_labels[f] = f[1] + def _get_axrect(self): + fontscale = self._font_properties._size / self.__class__._default_font_size + top_buff_size = 0.35 * fontscale + + x_axis_size = 1.35 * fontscale + y_axis_size = 0.7 * fontscale + right_buff_size = 0.2 * fontscale + + if is_sequence(self.figure_size): + figure_size = self.figure_size + else: + figure_size = (self.figure_size, self.figure_size) + + xbins = np.array([x_axis_size, figure_size[0], right_buff_size]) + ybins = np.array([y_axis_size, figure_size[1], top_buff_size]) + + x_frac_widths = xbins / xbins.sum() + y_frac_widths = ybins / ybins.sum() + + return ( + x_frac_widths[0], + y_frac_widths[0], + x_frac_widths[1], + y_frac_widths[1], + ) + @classmethod def from_lines( - cls, ds, fields, lines, figure_size=5, font_size=14, field_labels=None + cls, ds, fields, lines, figure_size=None, font_size=None, field_labels=None ): """ A class method for constructing a line plot from multiple sampling lines @@ -252,41 +271,6 @@ def from_lines( obj._setup_plots() return obj - def _get_plot_instance(self, field): - fontscale = self._font_properties._size / 14.0 - top_buff_size = 0.35 * fontscale - - x_axis_size = 1.35 * fontscale - y_axis_size = 0.7 * fontscale - right_buff_size = 0.2 * fontscale - - if is_sequence(self.figure_size): - figure_size = self.figure_size - else: - figure_size = (self.figure_size, self.figure_size) - - xbins = np.array([x_axis_size, figure_size[0], right_buff_size]) - ybins = np.array([y_axis_size, figure_size[1], top_buff_size]) - - size = [xbins.sum(), ybins.sum()] - - x_frac_widths = xbins / size[0] - y_frac_widths = ybins / size[1] - - axrect = ( - x_frac_widths[0], - y_frac_widths[0], - x_frac_widths[1], - y_frac_widths[1], - ) - - try: - plot = self.plots[field] - except KeyError: - plot = PlotMPL(self.figure_size, axrect, None, None) - self.plots[field] = plot - return plot - def _setup_plots(self): if self._plot_valid: return @@ -315,13 +299,10 @@ def _setup_plots(self): else: unit_x = self._x_unit - if field in self._y_units: - unit_y = self._y_units[field] - else: - unit_y = y.units + unit_y = plot.norm_handler.display_units - x = x.to(unit_x) - y = y.to(unit_y) + x.convert_to_units(unit_x) + y.convert_to_units(unit_y) # determine legend label str_seq = [] @@ -334,11 +315,18 @@ def _setup_plots(self): plot.axes.plot(x, y, label=legend_label) # apply log transforms if requested - if self._field_transform[field] != linear_transform: - if (y <= 0).any(): - plot.axes.set_yscale("symlog") - else: - plot.axes.set_yscale("log") + norm = plot.norm_handler.get_norm(data=y) + y_norm_type = type(norm) + if y_norm_type is Normalize: + plot.axes.set_yscale("linear") + elif y_norm_type is LogNorm: + plot.axes.set_yscale("log") + elif y_norm_type is SymLogNorm: + plot.axes.set_yscale("symlog") + else: + raise NotImplementedError( + f"LinePlot doesn't support y norm with type {type(norm)}" + ) # set font properties plot._set_font_properties(self._font_properties, None) @@ -409,17 +397,18 @@ def set_x_unit(self, unit_name): self._x_unit = unit_name @invalidate_plot - def set_unit(self, field, unit_name): + def set_unit(self, field, new_unit): """Set the unit used to plot the field Parameters ---------- field: str or field tuple The name of the field to set the units for - unit_name: str - The name of the unit to use for this field + new_unit: string or Unit object """ - self._y_units[self.data_source._determine_fields(field)[0]] = unit_name + field = self.data_source._determine_fields(field)[0] + pnh = self.plots[field].norm_handler + pnh.display_units = new_unit @invalidate_plot def annotate_title(self, field, title): diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index c0699d131b..45010333c6 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -6,20 +6,24 @@ import warnings from collections import defaultdict from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union -import numpy as np +import matplotlib +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from matplotlib.font_manager import FontProperties -from more_itertools.more import always_iterable +from unyt.dimensions import length from yt._maintenance.deprecation import issue_deprecation_warning +from yt._typing import Quantity from yt.config import ytcfg from yt.data_objects.time_series import DatasetSeries -from yt.funcs import dictWithFactory, ensure_dir, is_sequence, iter_fields, mylog -from yt.units import YTQuantity +from yt.funcs import ensure_dir, is_sequence, iter_fields from yt.units.unit_object import Unit # type: ignore from yt.utilities.definitions import formatted_length_unit_names -from yt.utilities.exceptions import YTNotInsideNotebook +from yt.utilities.exceptions import YTConfigurationError, YTNotInsideNotebook +from yt.visualization._commons import get_default_from_config +from yt.visualization._handlers import ColorbarHandler, NormHandler +from yt.visualization.base_plot_types import PlotMPL from ._commons import ( DEFAULT_FONT_PROPERTIES, @@ -30,6 +34,11 @@ validate_plot, ) +if sys.version_info >= (3, 8): + from typing import Final, Literal +else: + from typing_extensions import Final, Literal + latex_prefixes = { "u": r"\mu", } @@ -71,82 +80,17 @@ def newfunc(self, field, *args, **kwargs): return newfunc -def get_log_minorticks(vmin, vmax): - """calculate positions of linear minorticks on a log colorbar - - Parameters - ---------- - vmin : float - the minimum value in the colorbar - vmax : float - the maximum value in the colorbar - - """ - expA = np.floor(np.log10(vmin)) - expB = np.floor(np.log10(vmax)) - cofA = np.ceil(vmin / 10**expA).astype("int64") - cofB = np.floor(vmax / 10**expB).astype("int64") - lmticks = [] - while cofA * 10**expA <= cofB * 10**expB: - if expA < expB: - lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA)) - cofA = 1 - expA += 1 - else: - lmticks = np.hstack( - (lmticks, np.linspace(cofA, cofB, cofB - cofA + 1) * 10**expA) - ) - expA += 1 - return np.array(lmticks) +# define a singleton sentinel to be used as default value distinct from None +class Unset: + _instance = None + def __new__(cls): + if cls._instance is None: + cls._instance = object.__new__(cls) + return cls._instance -def get_symlog_minorticks(linthresh, vmin, vmax): - """calculate positions of linear minorticks on a symmetric log colorbar - - Parameters - ---------- - linthresh : float - the threshold for the linear region - vmin : float - the minimum value in the colorbar - vmax : float - the maximum value in the colorbar - - """ - if vmin > 0: - return get_log_minorticks(vmin, vmax) - elif vmax < 0 and vmin < 0: - return -get_log_minorticks(-vmax, -vmin) - elif vmin == 0: - return np.hstack((0, get_log_minorticks(linthresh, vmax))) - elif vmax == 0: - return np.hstack((-get_log_minorticks(linthresh, -vmin)[::-1], 0)) - else: - return np.hstack( - ( - -get_log_minorticks(linthresh, -vmin)[::-1], - 0, - get_log_minorticks(linthresh, vmax), - ) - ) - -field_transforms = {} - - -class FieldTransform: - def __init__(self, name, func): - self.name = name - self.func = func - field_transforms[name] = self - - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) - - -log_transform = FieldTransform("log10", np.log10) -linear_transform = FieldTransform("linear", lambda x: x) -symlog_transform = FieldTransform("symlog", None) +UNSET: Final = Unset() class PlotDictionary(defaultdict): @@ -173,66 +117,44 @@ def __init__(self, data_source, default_factory=None): class PlotContainer(abc.ABC): """A container for generic plots""" + _plot_dict_type: Type[PlotDictionary] = PlotDictionary _plot_type: Optional[str] = None _plot_valid = False - # Plot defaults - _colormap_config: dict - _log_config: dict - _units_config: dict + _default_figure_size = tuple(matplotlib.rcParams["figure.figsize"]) + _default_font_size = 14.0 - def __init__(self, data_source, figure_size, fontsize): + def __init__(self, data_source, figure_size=None, fontsize: Optional[float] = None): self.data_source = data_source self.ds = data_source.ds self.ts = self._initialize_dataset(self.ds) - if is_sequence(figure_size): - self.figure_size = float(figure_size[0]), float(figure_size[1]) - else: - self.figure_size = float(figure_size) + self.plots = self.__class__._plot_dict_type(data_source) + self._set_figure_size(figure_size) + + if fontsize is None: + fontsize = self.__class__._default_font_size if sys.version_info >= (3, 9): font_dict = DEFAULT_FONT_PROPERTIES | {"size": fontsize} else: - font_dict = {**DEFAULT_FONT_PROPERTIES, "size": fontsize} + font_dict = {**DEFAULT_FONT_PROPERTIES, "size": fontsize} # type:ignore self._font_properties = FontProperties(**font_dict) self._font_color = None self._xlabel = None self._ylabel = None - self._minorticks = {} - self._field_transform = {} - - self.setup_defaults() - - def setup_defaults(self): - def default_from_config(keys, defaults): - _keys = list(always_iterable(keys)) - _defaults = list(always_iterable(defaults)) - - def getter(field): - ftype, fname = self.data_source._determine_fields(field)[0] - ret = [ - ytcfg.get_most_specific("plot", ftype, fname, key, fallback=default) - for key, default in zip(_keys, _defaults) - ] - if len(ret) == 1: - return ret[0] - return ret - - return getter - - default_cmap = ytcfg.get("yt", "default_colormap") - self._colormap_config = dictWithFactory( - default_from_config("cmap", default_cmap) - )() - self._log_config = dictWithFactory( - default_from_config(["log", "linthresh"], [None, None]) - )() - self._units_config = dictWithFactory(default_from_config("units", [None]))() + self._minorticks: Dict[Tuple[str, str], bool] = {} @accepts_all_fields @invalidate_plot - def set_log(self, field, log, linthresh=None, symlog_auto=False): + def set_log( + self, + field, + log: Optional[bool] = None, + *, + linthresh: Optional[Union[float, Quantity, Literal["auto"]]] = None, + symlog_auto: Optional[bool] = None, # deprecated + ): """set a field to log, linear, or symlog. Symlog scaling is a combination of linear and log, where from 0 to a @@ -247,30 +169,64 @@ def set_log(self, field, log, linthresh=None, symlog_auto=False): field : string the field to set a transform if field == 'all', applies to all plots. - log : boolean - Log on/off: on means log scaling; off means linear scaling. Unless - a linthresh is set or symlog_auto is set in which case symlog is used. - linthresh : float, optional + log : boolean, optional + set log to True for log scaling, False for linear scaling. + linthresh : float, (float, str), unyt_quantity, or 'auto', optional when using symlog scaling, linthresh is the value at which scaling transitions from linear to logarithmic. linthresh must be positive. Note: setting linthresh will automatically enable symlog scale - symlog_auto : boolean - if symlog_auto is True, then yt will use symlog scaling and attempt to - determine a linthresh automatically. Setting a linthresh manually - overrides this value. + Note that *log* and *linthresh* are mutually exclusive arguments """ - if symlog_auto: - self._field_transform[field] = symlog_transform - if log: - self._field_transform[field] = log_transform - else: - self._field_transform[field] = linear_transform + if log is None and linthresh is None and symlog_auto is None: + raise TypeError("set_log requires log or linthresh be set") + + if symlog_auto is not None: + issue_deprecation_warning( + "the symlog_auto argument is deprecated. Use linthresh='auto' instead", + since="4.1", + stacklevel=5, + ) + if symlog_auto is True: + linthresh = "auto" + elif symlog_auto is False: + pass + else: + raise TypeError( + "Received invalid value for parameter symlog_auto. " + f"Expected a boolean, got {symlog_auto!r}" + ) + + if log is not None and linthresh is not None: + # we do not raise an error here for backward compatibility + warnings.warn( + f"log={log} has no effect because linthresh specified. Using symlog.", + stacklevel=4, + ) + + pnh = self.plots[field].norm_handler + if linthresh is not None: - if not linthresh > 0.0: - raise ValueError('"linthresh" must be positive') - self._field_transform[field] = symlog_transform - self._field_transform[field].func = linthresh + if isinstance(linthresh, str): + if linthresh == "auto": + pnh.norm_type = SymLogNorm + else: + raise ValueError( + "Expected a number, a unyt_quantity, a (float, 'unit') tuple, or 'auto'. " + f"Got linthresh={linthresh!r}" + ) + else: + # pnh takes care of switching to symlog when linthresh is set + pnh.linthresh = linthresh + elif log is True: + pnh.norm_type = LogNorm + elif log is False: + pnh.norm_type = Normalize + else: + raise TypeError( + f"Could not parse arguments log={log!r}, linthresh={linthresh!r}" + ) + return self def get_log(self, field): @@ -285,21 +241,61 @@ def get_log(self, field): """ # devnote : accepts_all_fields decorator is not applicable here because # the return variable isn't self + issue_deprecation_warning( + "The get_log method is not reliable and is deprecated. " + "Please do not rely on it.", + since="4.1", + ) log = {} if field == "all": fields = list(self.plots.keys()) else: fields = field for field in self.data_source._determine_fields(fields): - log[field] = self._field_transform[field] == log_transform + pnh = self.plots[field].norm_handler + if pnh.norm is not None: + log[field] = type(pnh.norm) is LogNorm + elif pnh.norm_type is not None: + log[field] = pnh.norm_type is LogNorm + else: + # the NormHandler object has no constraints yet + # so we'll assume defaults + log[field] = True return log @invalidate_plot - def set_transform(self, field, name): + def set_transform(self, field, name: str): field = self.data_source._determine_fields(field)[0] - if name not in field_transforms: - raise KeyError(name) - self._field_transform[field] = field_transforms[name] + pnh = self.plots[field].norm_handler + pnh.norm_type = { + "linear": Normalize, + "log10": LogNorm, + "symlog": SymLogNorm, + }[name] + return self + + @accepts_all_fields + @invalidate_plot + def set_norm(self, field, norm: Normalize): + r""" + Set a custom ``matplotlib.colors.Normalize`` to plot *field*. + + Any constraints previously set with `set_log`, `set_zlim` will be + dropped. + + Note that any float value attached to *norm* (e.g. vmin, vmax, + vcenter ...) will be read in the current displayed units, which can be + controlled with the `set_unit` method. + + Parameters + ---------- + field : str or tuple[str, str] + if field == 'all', applies to all plots. + norm : matplotlib.colors.Normalize + see https://matplotlib.org/stable/tutorials/colors/colormapnorms.html + """ + pnh = self.plots[field].norm_handler + pnh.norm = norm return self @accepts_all_fields @@ -322,6 +318,7 @@ def set_minorticks(self, field, state): self._minorticks[field] = state return self + @abc.abstractmethod def _setup_plots(self): # Left blank to be overridden in subclasses pass @@ -368,7 +365,6 @@ def _switch_ds(self, new_ds, data_source=None): lim = tuple(new_ds.quan(l.value, str(l.units)) for l in lim) setattr(self, lim_name, lim) self.plots.data_source = new_object - self._background_color.data_source = new_object self._colorbar_label.data_source = new_object self._setup_plots() @@ -465,6 +461,16 @@ def set_font_size(self, size): """ return self.set_font({"size": size}) + def _set_figure_size(self, size): + if size is None: + self.figure_size = self.__class__._default_figure_size + elif is_sequence(size): + if len(size) != 2: + raise TypeError(f"Expected a single float or a pair, got {size}") + self.figure_size = float(size[0]), float(size[1]) + else: + self.figure_size = float(size) + @invalidate_plot @invalidate_figure def set_figure_size(self, size): @@ -472,11 +478,13 @@ def set_figure_size(self, size): parameters ---------- - size : float - The size of the figure on the longest axis (in units of inches), - including the margins but not the colorbar. + size : float, a sequence of two floats, or None + The size of the figure (in units of inches), including the margins + but not the colorbar. If a single float is passed, it's interpreted + as the size along the long axis. + Pass None to reset """ - self.figure_size = float(size) + self._set_figure_size(size) return self @validate_plot @@ -538,8 +546,7 @@ def save( new_name = validate_image_name(name, suffix) if new_name == name: - # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here - for v in self.plots.values(): # type: ignore + for v in self.plots.values(): out_name = v.save(name, mpl_kwargs) names.append(out_name) return names @@ -560,8 +567,7 @@ def save( if "Cutting" in self.data_source.__class__.__name__: plot_type = "OffAxisSlice" - # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here - for k, v in self.plots.items(): # type: ignore + for k, v in self.plots.items(): if isinstance(k, tuple): k = k[1] @@ -796,7 +802,7 @@ def show_colorbar(self, field=None): self.plots[f].show_colorbar() return self - def hide_axes(self, field=None, draw_frame=False): + def hide_axes(self, field=None, draw_frame=None): """ Hides the axes for a plot and updates the size of the plot accordingly. Defaults to operating on all fields for a @@ -842,7 +848,7 @@ def hide_axes(self, field=None, draw_frame=False): if field is None: field = self.fields for f in iter_fields(field): - self.plots[f].hide_axes(draw_frame) + self.plots[f].hide_axes(draw_frame=draw_frame) return self def show_axes(self, field=None): @@ -864,19 +870,58 @@ def show_axes(self, field=None): return self -class ImagePlotContainer(PlotContainer): +class ImagePlotContainer(PlotContainer, abc.ABC): """A container for plots with colorbars.""" _colorbar_valid = False def __init__(self, data_source, figure_size, fontsize): super().__init__(data_source, figure_size, fontsize) - self.plots = PlotDictionary(data_source) self._callbacks = [] - self._cbar_minorticks = {} - self._background_color = PlotDictionary(self.data_source, lambda: "w") self._colorbar_label = PlotDictionary(self.data_source, lambda: None) + def _get_default_handlers( + self, field, default_display_units: Unit + ) -> Tuple[NormHandler, ColorbarHandler]: + + usr_units_str = get_default_from_config( + self.data_source, field=field, keys="units", defaults=[None] + ) + if usr_units_str is not None: + usr_units = Unit(usr_units_str) + d1 = usr_units.dimensions + d2 = default_display_units.dimensions + + if d1 == d2: + display_units = usr_units + elif getattr(self, "projected", False) and d2 / d1 == length: + path_length_units = Unit( + ytcfg.get_most_specific( + "plot", *field, "path_length_units", fallback="cm" + ), + registry=self.data_source.ds.unit_registry, + ) + display_units = usr_units * path_length_units + else: + raise YTConfigurationError( + f"Invalid units in configuration file for field {field!r}. " + f"Found {usr_units!r}" + ) + else: + display_units = default_display_units + + pnh = NormHandler(self.data_source, display_units=display_units) + + cbh = ColorbarHandler( + cmap=get_default_from_config( + self.data_source, + field=field, + keys="cmap", + defaults=[None], + ) + ) + return pnh, cbh + @accepts_all_fields @invalidate_plot def set_cmap(self, field, cmap): @@ -895,7 +940,7 @@ def set_cmap(self, field, cmap): """ self._colorbar_valid = False - self._colormap_config[field] = cmap + self.plots[field].colorbar_handler.cmap = cmap return self @accepts_all_fields @@ -914,19 +959,19 @@ def set_background_color(self, field, color=None): the color map """ - if color is None: - from yt.visualization.color_maps import _get_cmap - - cmap = self._colormap_config[field] - if isinstance(cmap, str): - cmap = _get_cmap(cmap) - color = cmap(0) - self._background_color[field] = color + cbh = self[field].colorbar_handler + cbh.background_color = color return self @accepts_all_fields @invalidate_plot - def set_zlim(self, field, zmin, zmax, dynamic_range=None): + def set_zlim( + self, + field, + zmin: Union[float, Quantity, Literal["min"], Unset] = UNSET, + zmax: Union[float, Quantity, Literal["max"], Unset] = UNSET, + dynamic_range: Optional[float] = None, + ): """set the scale of the colormap Parameters @@ -934,10 +979,10 @@ def set_zlim(self, field, zmin, zmax, dynamic_range=None): field : string the field to set a colormap scale if field == 'all', applies to all plots. - zmin : float, tuple, YTQuantity or str + zmin : float, Quantity, or 'min' the new minimum of the colormap scale. If 'min', will set to the minimum value in the current view. - zmax : float, tuple, YTQuantity or str + zmax : float, Quantity, or 'max' the new maximum of the colormap scale. If 'max', will set to the maximum value in the current view. @@ -947,50 +992,47 @@ def set_zlim(self, field, zmin, zmax, dynamic_range=None): The dynamic range of the image. If zmin == None, will set zmin = zmax / dynamic_range If zmax == None, will set zmax = zmin * dynamic_range - When dynamic_range is specified, defaults to setting - zmin = zmax / dynamic_range. """ + if zmin is UNSET and zmax is UNSET: + raise TypeError("Missing required argument zmin or zmax") + + if zmin is UNSET: + zmin = None + elif zmin is None: + # this sentinel value juggling is barely maintainable + # this use case is deprecated so we can simplify the logic here + # in the future and use `None` as the default value, + # instead of the custom sentinel UNSET + issue_deprecation_warning( + "Passing `zmin=None` explicitly is deprecated. " + "If you wish to explicitly set zmin to the minimal " + "data value, pass `zmin='min'` instead. " + "Otherwise leave this argument unset.", + since="4.1.0", + stacklevel=5, + ) + zmin = "min" + + if zmax is UNSET: + zmax = None + elif zmax is None: + # see above + issue_deprecation_warning( + "Passing `zmax=None` explicitly is deprecated. " + "If you wish to explicitly set zmax to the maximal " + "data value, pass `zmin='max'` instead. " + "Otherwise leave this argument unset.", + since="4.1.0", + stacklevel=5, + ) + zmax = "max" - def _sanitize_units(z, _field): - # convert dimensionful inputs to float - if isinstance(z, tuple): - z = self.ds.quan(*z) - if isinstance(z, YTQuantity): - try: - plot_units = self.frb[_field].units - z = z.to(plot_units).value - except AttributeError: - # only certain subclasses have a frb attribute - # they can rely on for inspecting units - mylog.warning( - "%s class doesn't support zmin/zmax" - " as tuples or unyt_quantity", - self.__class__.__name__, - ) - z = z.value - return z + pnh = self.plots[field].norm_handler + pnh.vmin = zmin + pnh.vmax = zmax + pnh.dynamic_range = dynamic_range - if field == "all": - fields = list(self.plots.keys()) - else: - fields = field - for field in self.data_source._determine_fields(fields): - myzmin = _sanitize_units(zmin, field) - myzmax = _sanitize_units(zmax, field) - if zmin == "min": - myzmin = self.plots[field].image._A.min() - if zmax == "max": - myzmax = self.plots[field].image._A.max() - if dynamic_range is not None: - if zmax is None: - myzmax = myzmin * dynamic_range - else: - myzmin = myzmax / dynamic_range - if myzmin > 0.0 and self._field_transform[field] == symlog_transform: - self._field_transform[field] = log_transform - self.plots[field].zmin = myzmin - self.plots[field].zmax = myzmax return self @accepts_all_fields @@ -1009,7 +1051,7 @@ def set_colorbar_minorticks(self, field, state): state : bool the state indicating 'on' (True) or 'off' (False) """ - self._cbar_minorticks[field] = state + self.plots[field].colorbar_handler.draw_minorticks = state return self @invalidate_plot @@ -1029,8 +1071,32 @@ def set_colorbar_label(self, field, label): ... ) """ + field = self.data_source._determine_fields(field) self._colorbar_label[field] = label return self def _get_axes_labels(self, field): return (self._xlabel, self._ylabel, self._colorbar_label[field]) + + +class BaseLinePlot(PlotContainer, abc.ABC): + + # A common ancestor to LinePlot and ProfilePlot + + @abc.abstractmethod + def _get_axrect(self): + pass + + def _get_plot_instance(self, field): + if field in self.plots: + return self.plots[field] + axrect = self._get_axrect() + + pnh = NormHandler(self.data_source, display_units=self.data_source[field].units) + finfo = self.data_source.ds._get_field_info(*field) + if not finfo.take_log: + pnh.norm_type = Normalize + plot = PlotMPL(self.figure_size, axrect, norm_handler=pnh) + self.plots[field] = plot + + return plot diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 553ba039b2..98f0f96902 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -5,12 +5,11 @@ import matplotlib import numpy as np +from matplotlib.colors import Normalize from more_itertools import always_iterable -from packaging.version import Version from unyt.exceptions import UnitConversionError from yt._maintenance.deprecation import issue_deprecation_warning -from yt.config import ytcfg from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.data_structures import YTSpatialPlotDataset from yt.funcs import fix_axis, fix_unitary, is_sequence, iter_fields, mylog, obj_length @@ -26,10 +25,10 @@ ) from yt.utilities.math_utils import ortho_find from yt.utilities.orientation import Orientation -from yt.visualization.base_plot_types import CallbackWrapper +from yt.visualization._handlers import ColorbarHandler, NormHandler +from yt.visualization.base_plot_types import CallbackWrapper, ImagePlotMPL -from ._commons import MPL_VERSION, _swap_axes_extents -from .base_plot_types import ImagePlotMPL +from ._commons import _swap_axes_extents, get_default_from_config from .fixed_resolution import ( FixedResolutionBuffer, OffAxisProjectionFixedResolutionBuffer, @@ -37,29 +36,17 @@ from .geo_plot_utils import get_mpl_transform from .plot_container import ( ImagePlotContainer, - get_log_minorticks, - get_symlog_minorticks, invalidate_data, invalidate_figure, invalidate_plot, - linear_transform, - log_transform, - symlog_transform, ) import sys # isort: skip -if sys.version_info < (3, 10): - # this function is deprecated in more_itertools - # because it is superseded by the standard library - from more_itertools import zip_equal +if sys.version_info >= (3, 10): + pass else: - - def zip_equal(*args): - # FUTURE: when only Python 3.10+ is supported, - # drop this conditional and call the builtin zip - # function directly where due - return zip(*args, strict=True) + from yt._maintenance.backports import zip def get_window_parameters(axis, center, width, ds): @@ -137,7 +124,7 @@ def validate_mesh_fields(data_source, fields): raise YTInvalidFieldType(invalid_fields) -class PlotWindow(ImagePlotContainer): +class PlotWindow(ImagePlotContainer, abc.ABC): r""" A plotting mechanism based around the concept of a window into a data source. It can have arbitrary fields, each of which will be @@ -250,21 +237,30 @@ def __init__( self._projection = get_mpl_transform(projection) self._transform = get_mpl_transform(transform) + self._setup_plots() + for field in self.data_source._determine_fields(self.fields): finfo = self.data_source.ds._get_field_info(*field) - if finfo.take_log: - self._field_transform[field] = log_transform + pnh = self.plots[field].norm_handler + if finfo.take_log is False: + # take_log can be `None` so we explicitly compare against a boolean + pnh.norm_type = Normalize else: - self._field_transform[field] = linear_transform - - log, linthresh = self._log_config[field] - if log is not None: - self.set_log(field, log, linthresh=linthresh) - - # Access the dictionary to force the key to be created - self._units_config[field] + # do nothing, the norm handler is responsible for + # determining a viable norm, and defaults to LogNorm/SymLogNorm + pass - self._setup_plots() + # override from user configuration if any + log, linthresh = get_default_from_config( + self.data_source, + field=field, + keys=["log", "linthresh"], + defaults=[None, None], + ) + if linthresh is not None: + self.set_log(field, linthresh=linthresh) + elif log is not None: + self.set_log(field, log) def __iter__(self): for ds in self.ts: @@ -309,6 +305,7 @@ def _recreate_frb(self): old_filters = self._frb._filters # Set the bounds if hasattr(self, "zlim"): + # Support OffAxisProjectionPlot and OffAxisSlicePlot bounds = self.xlim + self.ylim + self.zlim else: bounds = self.xlim + self.ylim @@ -324,50 +321,7 @@ def _recreate_frb(self): ) # At this point the frb has the valid bounds, size, aliasing, etc. - if old_fields is None: - self._frb._get_data_source_fields() - - # New frb, apply default units (if any) - for field, field_unit in self._units_config.items(): - if field_unit is None: - continue - - field_unit = Unit(field_unit, registry=self.ds.unit_registry) - is_projected = getattr(self, "projected", False) - if is_projected: - # Obtain config - path_length_units = Unit( - ytcfg.get_most_specific( - "plot", *field, "path_length_units", fallback="cm" - ), - registry=self.ds.unit_registry, - ) - units = field_unit * path_length_units - else: - units = field_unit - try: - self.frb[field].convert_to_units(units) - except UnitConversionError: - msg = ( - "Could not apply default units from configuration.\n" - "Tried converting projected field %s from %s to %s, retaining units %s:\n" - "\tgot units for field: %s" - ) - args = [ - field, - self.frb[field].units, - units, - field_unit, - units, - ] - if is_projected: - msg += "\n\tgot units for integration length: %s" - args += [path_length_units] - - msg += "\nCheck your configuration file." - - mylog.error(msg, *args) - else: + if old_fields is not None: # Restore the old fields for key, units in zip(old_fields, old_units): self._frb[key] @@ -497,7 +451,6 @@ def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None): The name of the field that is to be changed. new_unit : string or Unit object - The name of the new unit. equivalency : string, optional If set, the equivalency to use to convert the current units to @@ -508,9 +461,11 @@ def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None): Keyword arguments to be passed to the equivalency. Only used if ``equivalency`` is set. """ - for f, u in zip_equal(iter_fields(field), always_iterable(new_unit)): + for f, u in zip(iter_fields(field), always_iterable(new_unit), strict=True): self.frb.set_unit(f, u, equivalency, equivalency_kwargs) self._equivalencies[f] = (equivalency, equivalency_kwargs) + pnh = self.plots[f].norm_handler + pnh.display_units = u return self @invalidate_plot @@ -676,6 +631,7 @@ def _set_window(self, bounds): self.xlim = tuple(bounds[0:2]) self.ylim = tuple(bounds[2:4]) if len(bounds) == 6: + # Support OffAxisProjectionPlot and OffAxisSlicePlot self.zlim = tuple(bounds[4:6]) mylog.info("xlim = %f %f", self.xlim[0], self.xlim[1]) mylog.info("ylim = %f %f", self.ylim[0], self.ylim[1]) @@ -1115,72 +1071,28 @@ def _setup_plots(self): extent = [*extentx, *extenty] + image = self.frb[f] + font_size = self._font_properties.get_size() + if f in self.plots.keys(): - zlim = (self.plots[f].zmin, self.plots[f].zmax) + pnh = self.plots[f].norm_handler + cbh = self.plots[f].colorbar_handler else: - zlim = (None, None) - - image = self.frb[f] - if self._field_transform[f] == log_transform: - msg = None - use_symlog = False - if zlim != (None, None): - pass - elif np.nanmax(image) == np.nanmin(image): - msg = f"Plotting {f}: All values = {np.nanmax(image)}" - elif np.nanmax(image) <= 0: - msg = ( - f"Plotting {f}: All negative values. Max = {np.nanmax(image)}." - ) - use_symlog = True - elif not np.any(np.isfinite(image)): - msg = f"Plotting {f}: All values = NaN." - elif np.nanmax(image) > 0.0 and np.nanmin(image) <= 0: - msg = ( - f"Plotting {f}: Both positive and negative values. " - f"Min = {np.nanmin(image)}, Max = {np.nanmax(image)}." + pnh, cbh = self._get_default_handlers( + field=f, default_display_units=image.units + ) + if pnh.display_units != image.units: + equivalency, equivalency_kwargs = self._equivalencies[f] + image.convert_to_units( + pnh.display_units, equivalency, **equivalency_kwargs ) - use_symlog = True - elif ( - (Version("3.3") <= MPL_VERSION < Version("3.5")) - and np.nanmax(image) > 0.0 - and np.nanmin(image) == 0 - ): - # normally, a LogNorm scaling would still be OK here because - # LogNorm will mask 0 values when calculating vmin. But - # due to a bug in matplotlib's imshow, if the data range - # spans many orders of magnitude while containing zero points - # vmin can get rescaled to 0, resulting in an error when the image - # gets drawn. So here we switch to symlog to avoid that until - # a fix is in -- see PR #3161 and linked issue. - cutoff_sigdigs = 15 - if ( - np.log10(np.nanmax(image[np.isfinite(image)])) - - np.log10(np.nanmin(image[image > 0])) - > cutoff_sigdigs - ): - msg = f"Plotting {f}: Wide range and zeros." - use_symlog = True - if msg is not None: - mylog.warning(msg) - if use_symlog: - mylog.warning("Switching to symlog colorbar scaling.") - self._field_transform[f] = symlog_transform - self._field_transform[f].func = None - else: - mylog.warning("Switching to linear colorbar scaling.") - self._field_transform[f] = linear_transform - - font_size = self._font_properties.get_size() fig = None axes = None cax = None - draw_colorbar = True draw_axes = True - draw_frame = draw_axes + draw_frame = None if f in self.plots: - draw_colorbar = self.plots[f]._draw_colorbar draw_axes = self.plots[f]._draw_axes draw_frame = self.plots[f]._draw_frame if self.plots[f].figure is not None: @@ -1212,11 +1124,7 @@ def _setup_plots(self): self.plots[f] = WindowPlotMPL( ia, - self._field_transform[f].name, - self._field_transform[f].func, - self._colormap_config[f], extent, - zlim, self.figure_size, font_size, aspect, @@ -1225,6 +1133,8 @@ def _setup_plots(self): cax, self._projection, self._transform, + norm_handler=pnh, + colorbar_handler=cbh, ) axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) @@ -1282,10 +1192,6 @@ def _setup_plots(self): self.plots[f].axes.set_xlabel(labels[0]) self.plots[f].axes.set_ylabel(labels[1]) - color = self._background_color[f] - - self.plots[f].axes.set_facecolor(color) - # Determine the units of the data units = Unit(self.frb[f].units, registry=self.ds.unit_registry) units = units.latex_representation() @@ -1317,63 +1223,9 @@ def _setup_plots(self): else: self.plots[f].axes.minorticks_off() - # colorbar minorticks - if f not in self._cbar_minorticks: - self._cbar_minorticks[f] = True - - if self._cbar_minorticks[f]: - vmin = np.float64(self.plots[f].cb.norm.vmin) - vmax = np.float64(self.plots[f].cb.norm.vmax) - - if self._field_transform[f] == linear_transform: - self.plots[f].cax.minorticks_on() - - elif self._field_transform[f] == symlog_transform: - if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"): - # no known working method to draw symlog minor ticks - # see https://github.com/yt-project/yt/issues/3535 - pass - else: - flinthresh = 10 ** np.floor( - np.log10(self.plots[f].cb.norm.linthresh) - ) - absmax = np.abs((vmin, vmax)).max() - if (absmax - flinthresh) / absmax < 0.1: - flinthresh /= 10 - mticks = get_symlog_minorticks(flinthresh, vmin, vmax) - if MPL_VERSION < Version("3.5.0b"): - # https://github.com/matplotlib/matplotlib/issues/21258 - mticks = self.plots[f].image.norm(mticks) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - - elif self._field_transform[f] == log_transform: - if MPL_VERSION >= Version("3.0.0"): - self.plots[f].cax.minorticks_on() - self.plots[f].cax.xaxis.set_visible(False) - else: - mticks = self.plots[f].image.norm( - get_log_minorticks(vmin, vmax) - ) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - - else: - mylog.error( - "Unable to draw cbar minorticks for field " - "%s with transform %s ", - f, - self._field_transform[f], - ) - self._cbar_minorticks[f] = False - - if not self._cbar_minorticks[f]: - self.plots[f].cax.minorticks_off() - if not draw_axes: self.plots[f]._toggle_axes(draw_axes, draw_frame) - if not draw_colorbar: - self.plots[f]._toggle_colorbar(draw_colorbar) - self._set_font_properties() self.run_callbacks() @@ -2652,11 +2504,7 @@ class WindowPlotMPL(ImagePlotMPL): def __init__( self, data, - cbname, - cblinthresh, - cmap, extent, - zlim, figure_size, fontsize, aspect, @@ -2665,55 +2513,41 @@ def __init__( cax, mpl_proj, mpl_transform, + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, ): - from matplotlib.ticker import ScalarFormatter - - self._draw_colorbar = True - self._draw_axes = True - self._draw_frame = True - self._fontsize = fontsize - self._figure_size = figure_size self._projection = mpl_proj self._transform = mpl_transform + self._setup_layout_constraints(figure_size, fontsize) + self._draw_frame = True + self._aspect = ((extent[1] - extent[0]) / (extent[3] - extent[2])).in_cgs() + self._unit_aspect = aspect + # Compute layout - fontscale = float(fontsize) / 18.0 + self._figure_size = figure_size + self._draw_axes = True + fontscale = float(fontsize) / self.__class__._default_font_size if fontscale < 1.0: fontscale = np.sqrt(fontscale) if is_sequence(figure_size): - fsize = figure_size[0] + self._cb_size = 0.0375 * figure_size[0] else: - fsize = figure_size - self._cb_size = 0.0375 * fsize + self._cb_size = 0.0375 * figure_size self._ax_text_size = [1.2 * fontscale, 0.9 * fontscale] self._top_buff_size = 0.30 * fontscale - self._aspect = ((extent[1] - extent[0]) / (extent[3] - extent[2])).in_cgs() - self._unit_aspect = aspect - - size, axrect, caxrect = self._get_best_layout() - - super().__init__(size, axrect, caxrect, zlim, figure, axes, cax) - self._init_image(data, cbname, cblinthresh, cmap, extent, aspect) + super().__init__( + figure=figure, + axes=axes, + cax=cax, + norm_handler=norm_handler, + colorbar_handler=colorbar_handler, + ) - # In matplotlib 2.1 and newer we'll be able to do this using - # self.image.axes.ticklabel_format - # See https://github.com/matplotlib/matplotlib/pull/6337 - formatter = ScalarFormatter(useMathText=True) - formatter.set_scientific(True) - formatter.set_powerlimits((-2, 3)) - self.image.axes.xaxis.set_major_formatter(formatter) - self.image.axes.yaxis.set_major_formatter(formatter) - if cbname == "linear": - self.cb.formatter.set_scientific(True) - try: - self.cb.formatter.set_useMathText(True) - except AttributeError: - # this is only available in mpl > 2.1 - pass - self.cb.formatter.set_powerlimits((-2, 3)) - self.cb.update_ticks() + self._init_image(data, extent, aspect) def _create_axes(self, axrect): self.axes = self.figure.add_axes(axrect, projection=self._projection) diff --git a/yt/visualization/profile_plotter.py b/yt/visualization/profile_plotter.py index 5453757831..17855e23d6 100644 --- a/yt/visualization/profile_plotter.py +++ b/yt/visualization/profile_plotter.py @@ -1,33 +1,27 @@ import base64 import builtins import os -from collections import OrderedDict from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple, Union import matplotlib import numpy as np -from matplotlib.font_manager import FontProperties from more_itertools.more import always_iterable, unzip -from packaging.version import Version from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys from yt.data_objects.static_output import Dataset from yt.frontends.ytdata.data_structures import YTProfileDataset -from yt.funcs import is_sequence, iter_fields, matplotlib_style_context +from yt.funcs import iter_fields, matplotlib_style_context from yt.utilities.exceptions import YTNotInsideNotebook -from yt.utilities.logger import ytLogger as mylog +from yt.visualization._handlers import ColorbarHandler, NormHandler +from yt.visualization.base_plot_types import ImagePlotMPL, PlotMPL from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer -from ._commons import DEFAULT_FONT_PROPERTIES, MPL_VERSION, validate_image_name -from .base_plot_types import ImagePlotMPL, PlotMPL +from ._commons import validate_image_name from .plot_container import ( + BaseLinePlot, ImagePlotContainer, - PlotContainer, - get_log_minorticks, invalidate_plot, - linear_transform, - log_transform, validate_plot, ) @@ -42,42 +36,6 @@ def newfunc(*args, **kwargs): return newfunc -class PlotContainerDict(OrderedDict): - def __missing__(self, key): - plot = PlotMPL((10, 8), [0.1, 0.1, 0.8, 0.8], None, None) - self[key] = plot - return self[key] - - -class FigureContainer(OrderedDict): - def __init__(self, plots): - self.plots = plots - super().__init__() - - def __missing__(self, key): - self[key] = self.plots[key].figure - return self[key] - - def __iter__(self): - return iter(self.plots) - - -class AxesContainer(OrderedDict): - def __init__(self, plots): - self.plots = plots - self.ylim = {} - self.xlim = (None, None) - super().__init__() - - def __missing__(self, key): - self[key] = self.plots[key].axes - return self[key] - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self.ylim[key] = (None, None) - - def sanitize_label(labels, nprofiles): labels = list(always_iterable(labels)) or [None] @@ -114,7 +72,7 @@ def data_object_or_all_data(data_source): return data_source -class ProfilePlot(PlotContainer): +class ProfilePlot(BaseLinePlot): r""" Create a 1d profile plot from a data source or from a list of profile objects. @@ -217,6 +175,8 @@ class ProfilePlot(PlotContainer): Use set_line_property to change line properties of one or all profiles. """ + _default_figure_size = (10.0, 8.0) + _default_font_size = 18.0 x_log = None y_log = None @@ -238,7 +198,6 @@ def __init__( x_log=True, y_log=True, ): - data_source = data_object_or_all_data(data_source) y_fields = list(iter_fields(y_fields)) logs = {x_field: bool(x_log)} @@ -266,7 +225,42 @@ def __init__( if not isinstance(plot_spec, list): plot_spec = [plot_spec.copy() for p in profiles] - ProfilePlot._initialize_instance(self, profiles, label, plot_spec, y_log) + ProfilePlot._initialize_instance( + self, data_source, profiles, label, plot_spec, y_log + ) + + @classmethod + def _initialize_instance( + cls, + obj, + data_source, + profiles, + labels, + plot_specs, + y_log, + ): + obj._plot_title = {} + obj._plot_text = {} + obj._text_xpos = {} + obj._text_ypos = {} + obj._text_kwargs = {} + + super(ProfilePlot, obj).__init__(data_source) + obj.profiles = list(always_iterable(profiles)) + obj.x_log = None + obj.y_log = sanitize_field_tuple_keys(y_log, data_source) or {} + obj.y_title = {} + obj.x_title = None + obj.label = sanitize_label(labels, len(obj.profiles)) + if plot_specs is None: + plot_specs = [dict() for p in obj.profiles] + obj.plot_spec = plot_specs + obj._xlim = (None, None) + obj._setup_plots() + return obj + + def _get_axrect(self): + return (0.1, 0.1, 0.8, 0.8) @validate_plot def save( @@ -291,13 +285,14 @@ def save( if not self._plot_valid: self._setup_plots() - # Mypy is hardly convinced that we have a `plots` and a `profile` attr + # Mypy is hardly convinced that we have a `profiles` attribute # at this stage, so we're lasily going to deactivate it locally - unique = set(self.plots.values()) # type: ignore - if len(unique) < len(self.plots): # type: ignore - iters = zip(range(len(unique)), sorted(unique)) + unique = set(self.plots.values()) + iters: Iterable[Tuple[Union[int, Tuple[str, str]], PlotMPL]] + if len(unique) < len(self.plots): + iters = enumerate(sorted(unique)) else: - iters = self.plots.items() # type: ignore + iters = self.plots.items() if name is None: if len(self.profiles) == 1: # type: ignore @@ -314,11 +309,10 @@ def save( names = [] for uid, plot in iters: - if isinstance(uid, tuple): # type: ignore + if isinstance(uid, tuple): uid = uid[1] # type: ignore uid_name = f"{prefix}_1d-Profile_{xfn}_{uid}{suffix}" names.append(uid_name) - mylog.info("Saving %s", uid_name) with matplotlib_style_context(): plot.save(uid_name, mpl_kwargs=mpl_kwargs) return names @@ -373,10 +367,10 @@ def _repr_html_(self): def _setup_plots(self): if self._plot_valid: return - for f in self.axes: - self.axes[f].cla() + for f, p in self.plots.items(): + p.axes.cla() if f in self._plot_text: - self.plots[f].axes.text( + p.axes.text( self._text_xpos[f], self._text_ypos[f], self._plot_text[f], @@ -387,7 +381,8 @@ def _setup_plots(self): for i, profile in enumerate(self.profiles): for field, field_data in profile.items(): - self.axes[field].plot( + plot = self._get_plot_instance(field) + plot.axes.plot( np.array(profile.x), np.array(field_data), label=self.label[i], @@ -396,7 +391,7 @@ def _setup_plots(self): for profile in self.profiles: for fname in profile.keys(): - axes = self.axes[fname] + axes = self.plots[fname].axes xscale, yscale = self._get_field_log(fname, profile) xtitle, ytitle = self._get_field_title(fname, profile) @@ -406,8 +401,10 @@ def _setup_plots(self): axes.set_ylabel(ytitle) axes.set_xlabel(xtitle) - axes.set_ylim(*self.axes.ylim[fname]) - axes.set_xlim(*self.axes.xlim) + pnh = self.plots[fname].norm_handler + + axes.set_ylim(pnh.vmin, pnh.vmax) + axes.set_xlim(*self._xlim) if fname in self._plot_title: axes.set_title(self._plot_title[fname]) @@ -417,31 +414,6 @@ def _setup_plots(self): self._set_font_properties() self._plot_valid = True - @classmethod - def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log): - obj._plot_title = {} - obj._plot_text = {} - obj._text_xpos = {} - obj._text_ypos = {} - obj._text_kwargs = {} - - obj._font_properties = FontProperties(**DEFAULT_FONT_PROPERTIES) - obj._font_color = None - obj.profiles = list(always_iterable(profiles)) - obj.x_log = None - obj.y_log = sanitize_field_tuple_keys(y_log, obj.profiles[0].data_source) or {} - obj.y_title = {} - obj.x_title = None - obj.label = sanitize_label(labels, len(obj.profiles)) - if plot_specs is None: - plot_specs = [dict() for p in obj.profiles] - obj.plot_spec = plot_specs - obj.plots = PlotContainerDict() - obj.figures = FigureContainer(obj.plots) - obj.axes = AxesContainer(obj.plots) - obj._setup_plots() - return obj - @classmethod def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None): r""" @@ -495,7 +467,10 @@ def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None): "Profiles list and plot_specs list must be the same size." ) obj = cls.__new__(cls) - return cls._initialize_instance(obj, profiles, labels, plot_specs, y_log) + profiles = list(always_iterable(profiles)) + return cls._initialize_instance( + obj, profiles[0].data_source, profiles, labels, plot_specs, y_log + ) @invalidate_plot def set_line_property(self, property, value, index=None): @@ -642,7 +617,7 @@ def set_xlim(self, xmin=None, xmax=None): >>> pp.save() """ - self.axes.xlim = (xmin, xmax) + self._xlim = (xmin, xmax) for i, p in enumerate(self.profiles): if xmin is None: xmi = p.x_bins.min() @@ -707,12 +682,14 @@ def set_ylim(self, field, ymin=None, ymax=None): >>> pp.save() """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: field = profile.field_map[field] - self.axes.ylim[field] = (ymin, ymax) + pnh = self.plots[field].norm_handler + pnh.vmin = ymin + pnh.vmax = ymax # Continue on to the next profile. break return self @@ -791,7 +768,7 @@ def annotate_title(self, title, field="all"): ... ) """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: @@ -843,7 +820,7 @@ def annotate_text(self, xpos=0.0, ypos=0.0, text=None, field="all", **text_kwarg >>> plot.save() """ - fields = list(self.axes.keys()) if field == "all" else field + fields = list(self.plots.keys()) if field == "all" else field for profile in self.profiles: for field in profile.data_source._determine_fields(fields): if field in profile.field_map: @@ -1057,10 +1034,6 @@ def _get_field_log(self, field_z, profile): scales = {True: "log", False: "linear"} return scales[x_log], scales[y_log], scales[z_log] - def _recreate_frb(self): - # needed for API compatibility with PlotWindow - pass - @property def profile(self): if not self._profile_valid: @@ -1078,45 +1051,26 @@ def _setup_plots(self): fig = None axes = None cax = None - draw_colorbar = True draw_axes = True - zlim = (None, None) xlim = self._xlim ylim = self._ylim if f in self.plots: - draw_colorbar = self.plots[f]._draw_colorbar + pnh = self.plots[f].norm_handler + cbh = self.plots[f].colorbar_handler draw_axes = self.plots[f]._draw_axes - zlim = (self.plots[f].zmin, self.plots[f].zmax) if self.plots[f].figure is not None: fig = self.plots[f].figure axes = self.plots[f].axes cax = self.plots[f].cax + else: + pnh, cbh = self._get_default_handlers( + field=f, default_display_units=self.profile[f].units + ) x_scale, y_scale, z_scale = self._get_field_log(f, self.profile) x_title, y_title, z_title = self._get_field_title(f, self.profile) - if zlim == (None, None): - if z_scale == "log": - positive_values = data[data > 0.0] - if len(positive_values) == 0: - mylog.warning( - "Profiled field %s has no positive values. Max = %f.", - f, - np.nanmax(data), - ) - mylog.warning("Switching to linear colorbar scaling.") - zmin = np.nanmin(data) - z_scale = "linear" - self._field_transform[f] = linear_transform - else: - zmin = positive_values.min() - self._field_transform[f] = log_transform - else: - zmin = np.nanmin(data) - self._field_transform[f] = linear_transform - zlim = [zmin, np.nanmax(data)] - font_size = self._font_properties.get_size() f = self.profile.data_source._determine_fields(f)[0] @@ -1124,9 +1078,7 @@ def _setup_plots(self): # override the colorbar here. splat_color = getattr(self, "splat_color", None) if splat_color is not None: - cmap = matplotlib.colors.ListedColormap(splat_color, "dummy") - else: - cmap = self._colormap_config[f] + cbh.cmap = matplotlib.colors.ListedColormap(splat_color, "dummy") masked_data = data.copy() masked_data[~self.profile.used] = np.nan @@ -1136,19 +1088,18 @@ def _setup_plots(self): masked_data, x_scale, y_scale, - z_scale, - cmap, - zlim, self.figure_size, font_size, fig, axes, cax, shading=self._shading, + norm_handler=pnh, + colorbar_handler=cbh, ) self.plots[f]._toggle_axes(draw_axes) - self.plots[f]._toggle_colorbar(draw_colorbar) + self.plots[f]._toggle_colorbar(cbh.draw_cbar) self.plots[f].axes.xaxis.set_label_text(x_title) self.plots[f].axes.yaxis.set_label_text(y_title) @@ -1157,10 +1108,6 @@ def _setup_plots(self): self.plots[f].axes.set_xlim(xlim) self.plots[f].axes.set_ylim(ylim) - color = self._background_color[f] - - self.plots[f].axes.set_facecolor(color) - if f in self._plot_text: self.plots[f].axes.text( self._text_xpos[f], @@ -1181,25 +1128,6 @@ def _setup_plots(self): else: self.plots[f].axes.minorticks_off() - # colorbar minorticks - if f not in self._cbar_minorticks: - self._cbar_minorticks[f] = True - if self._cbar_minorticks[f]: - if self._field_transform[f] == linear_transform: - self.plots[f].cax.minorticks_on() - elif MPL_VERSION < Version("3.0.0"): - # before matplotlib 3 log-scaled colorbars internally used - # a linear scale going from zero to one and did not draw - # minor ticks. Since we want minor ticks, calculate - # where the minor ticks should go in this linear scale - # and add them manually. - vmin = np.float64(self.plots[f].cb.norm.vmin) - vmax = np.float64(self.plots[f].cb.norm.vmax) - mticks = self.plots[f].image.norm(get_log_minorticks(vmin, vmax)) - self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) - else: - self.plots[f].cax.minorticks_off() - self._set_font_properties() # if this is a particle plot with one color only, hide the cbar here @@ -1423,7 +1351,7 @@ def set_log(self, field, log): self.y_log = log self._profile_valid = False elif field in p.field_data: - self.z_log[field] = log + super().set_log(field, log) else: raise KeyError(f"Field {field} not in phase plot!") return self @@ -1447,7 +1375,7 @@ def set_unit(self, field, unit): self.profile.set_y_unit(unit) elif fd in self.profile.field_data.keys(): self.profile.set_field_unit(field, unit) - self.plots[field].zmin, self.plots[field].zmax = (None, None) + self.plots[field].norm_handler.display_units = unit else: raise KeyError(f"Field {field} not in phase plot!") return self @@ -1578,50 +1506,47 @@ def __init__( data, x_scale, y_scale, - z_scale, - cmap, - zlim, figure_size, fontsize, figure, axes, cax, shading="nearest", + *, + norm_handler: NormHandler, + colorbar_handler: ColorbarHandler, ): self._initfinished = False - self._draw_colorbar = True - self._draw_axes = True - self._figure_size = figure_size self._shading = shading - # Compute layout - fontscale = float(fontsize) / 18.0 - if fontscale < 1.0: - fontscale = np.sqrt(fontscale) - - if is_sequence(figure_size): - self._cb_size = 0.0375 * figure_size[0] - else: - self._cb_size = 0.0375 * figure_size - self._ax_text_size = [1.1 * fontscale, 0.9 * fontscale] - self._top_buff_size = 0.30 * fontscale - self._aspect = 1.0 - - size, axrect, caxrect = self._get_best_layout() - - super().__init__(size, axrect, caxrect, zlim, figure, axes, cax) + self._setup_layout_constraints(figure_size, fontsize) + + # this line is added purely to prevent exact image comparison tests + # to fail, but eventually we should embrace the change and + # use similar values for PhasePlotMPL and WindowPlotMPL + self._ax_text_size[0] *= 1.1 / 1.2 # TODO: remove this + + super().__init__( + figure=figure, + axes=axes, + cax=cax, + norm_handler=norm_handler, + colorbar_handler=colorbar_handler, + ) - self._init_image(x_data, y_data, data, x_scale, y_scale, z_scale, zlim, cmap) + self._init_image(x_data, y_data, data, x_scale, y_scale) self._initfinished = True def _init_image( - self, x_data, y_data, image_data, x_scale, y_scale, z_scale, zlim, cmap + self, + x_data, + y_data, + image_data, + x_scale, + y_scale, ): """Store output of imshow in image variable""" - if z_scale == "log": - norm = matplotlib.colors.LogNorm(zlim[0], zlim[1]) - elif z_scale == "linear": - norm = matplotlib.colors.Normalize(zlim[0], zlim[1]) + norm = self.norm_handler.get_norm(image_data) self.image = None self.cb = None @@ -1630,16 +1555,10 @@ def _init_image( np.array(y_data), np.array(image_data.T), norm=norm, - cmap=cmap, + cmap=self.colorbar_handler.cmap, shading=self._shading, ) + self._set_axes(norm) self.axes.set_xscale(x_scale) self.axes.set_yscale(y_scale) - self.cb = self.figure.colorbar(self.image, self.cax) - if z_scale == "linear": - self.cb.formatter.set_scientific(True) - self.cb.formatter.set_powerlimits((-2, 3)) - self.cb.update_ticks() - - self.cax.tick_params(which="both", axis="y", direction="in") diff --git a/yt/visualization/tests/test_norm_api_custom_norm.py b/yt/visualization/tests/test_norm_api_custom_norm.py new file mode 100644 index 0000000000..14e28581a9 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_custom_norm.py @@ -0,0 +1,31 @@ +import matplotlib +from nose.plugins.attrib import attr +from packaging.version import Version + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds, skipif +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + +MPL_VERSION = Version(matplotlib.__version__) + + +@skipif( + MPL_VERSION < Version("3.2"), + reason=f"TwoSlopeNorm requires MPL 3.2, we have {MPL_VERSION}", +) +@attr(ANSWER_TEST_TAG) +def test_sliceplot_custom_norm(): + from matplotlib.colors import TwoSlopeNorm + + ds = fake_random_ds(16) + + def create_image(filename_prefix): + field = ("gas", "density") + p = SlicePlot(ds, "z", field) + p.set_norm(field, norm=(TwoSlopeNorm(vcenter=0, vmin=-0.5, vmax=1))) + p.save(f"{filename_prefix}") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_custom_norm" + test.answer_name = "sliceplot_custom_norm" + yield test diff --git a/yt/visualization/tests/test_norm_api_inf_zlim.py b/yt/visualization/tests/test_norm_api_inf_zlim.py new file mode 100644 index 0000000000..17c89e3d16 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_inf_zlim.py @@ -0,0 +1,39 @@ +import numpy as np +from nose.plugins.attrib import attr + +from yt.loaders import load_uniform_grid +from yt.testing import ANSWER_TEST_TAG +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + + +@attr(ANSWER_TEST_TAG) +def test_inf_and_finite_values_zlim(): + # see https://github.com/yt-project/yt/issues/3901 + shape = (32, 16, 1) + a = np.ones(16) + b = np.ones((32, 16)) + c = np.reshape(a * b, shape) + + # injecting an inf + c[0, 0, 0] = np.inf + + data = {("gas", "density"): c} + + ds = load_uniform_grid( + data, + shape, + bbox=np.array([[0, 1], [0, 1], [0, 1]]), + ) + + def create_image(filename_prefix): + p = SlicePlot(ds, "z", ("gas", "density")) + + # setting zlim manually + p.set_zlim(("gas", "density"), -10, 10) + p.save(filename_prefix) + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_inf_and_finite_values_zlim" + test.answer_name = "inf_and_finite_values_zlim" + yield test diff --git a/yt/visualization/tests/test_norm_api_lineplot.py b/yt/visualization/tests/test_norm_api_lineplot.py new file mode 100644 index 0000000000..3da3b7aa96 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_lineplot.py @@ -0,0 +1,32 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import LinePlot + + +@attr(ANSWER_TEST_TAG) +def test_lineplot_set_axis_properties(): + ds = fake_random_ds(16) + + def create_image(filename_prefix): + p = LinePlot( + ds, + ("gas", "density"), + start_point=[0, 0, 0], + end_point=[1, 1, 1], + npoints=32, + ) + p.set_x_unit("cm") + p.save(f"{filename_prefix}_xunit") + + p.set_unit(("gas", "density"), "kg/cm**3") + p.save(f"{filename_prefix}_xunit_zunit") + + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_xunit_zunit_lin") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_lineplot_set_axis_properties" + test.answer_name = "lineplot_set_axis_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_particleplot.py b/yt/visualization/tests/test_norm_api_particleplot.py new file mode 100644 index 0000000000..0cd2167a29 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_particleplot.py @@ -0,0 +1,29 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_particle_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import ParticleProjectionPlot + + +@attr(ANSWER_TEST_TAG) +def test_particleprojectionplot_set_colorbar_properties(): + ds = fake_particle_ds(npart=100) + + def create_image(filename_prefix): + field = ("all", "particle_mass") + p = ParticleProjectionPlot(ds, 2, field) + p.set_buff_size(10) + + p.set_unit(field, "Msun") + p.save(f"{filename_prefix}_set_unit") + + p.set_zlim(field, zmax=1e-35) + p.save(f"{filename_prefix}_set_unit_zlim") + + p.set_log(field, False) + p.save(f"{filename_prefix}_set_unit_zlim_log") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_particleprojectionplot_set_colorbar_properties" + test.answer_name = "particleprojectionplot_set_colorbar_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py new file mode 100644 index 0000000000..8125040b6a --- /dev/null +++ b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_explicit.py @@ -0,0 +1,33 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, add_noise_fields, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import PhasePlot + + +@attr(ANSWER_TEST_TAG) +def test_phaseplot_set_colorbar_properties_explicit(): + ds = fake_random_ds(16) + add_noise_fields(ds) + + def create_image(filename_prefix): + my_sphere = ds.sphere("c", 1) + p = PhasePlot( + my_sphere, + ("gas", "noise1"), + ("gas", "noise3"), + [("gas", "density")], + weight_field=None, + ) + # using explicit units, we expect the colorbar units to stay unchanged + p.set_zlim(("gas", "density"), zmin=(1e36, "kg/AU**3")) + p.save(f"{filename_prefix}_set_zlim_explicit") + + # ... until we set them explicitly + p.set_unit(("gas", "density"), "kg/AU**3") + p.save(f"{filename_prefix}_set_zlim_set_unit_explicit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_phaseplot_set_colorbar_properties_explicit" + test.answer_name = "phaseplot_set_colorbar_properties_explicit" + yield test diff --git a/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py new file mode 100644 index 0000000000..6bb1f29513 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_phaseplot_set_colorbar_implicit.py @@ -0,0 +1,33 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, add_noise_fields, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import PhasePlot + + +@attr(ANSWER_TEST_TAG) +def test_phaseplot_set_colorbar_properties_implicit(): + ds = fake_random_ds(16) + add_noise_fields(ds) + + def create_image(filename_prefix): + my_sphere = ds.sphere("c", 1) + p = PhasePlot( + my_sphere, + ("gas", "noise1"), + ("gas", "noise3"), + [("gas", "density")], + weight_field=None, + ) + # using implicit units + p.set_zlim(("gas", "density"), zmax=10) + p.save(f"{filename_prefix}_set_zlim_implicit") + + # changing units should affect the colorbar and not the image + p.set_unit(("gas", "density"), "kg/AU**3") + p.save(f"{filename_prefix}_set_zlim_set_unit_implicit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_phaseplot_set_colorbar_properties_implicit" + test.answer_name = "phaseplot_set_colorbar_properties_implicit" + yield test diff --git a/yt/visualization/tests/test_norm_api_profileplot.py b/yt/visualization/tests/test_norm_api_profileplot.py new file mode 100644 index 0000000000..ee0e66ef41 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_profileplot.py @@ -0,0 +1,29 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import ProfilePlot + + +@attr(ANSWER_TEST_TAG) +def test_profileplot_set_axis_properties(): + ds = fake_random_ds(16) + + def create_image(filename_prefix): + disk = ds.disk(ds.domain_center, [0.0, 0.0, 1.0], (10, "m"), (1, "m")) + p = ProfilePlot(disk, ("gas", "density"), [("gas", "velocity_x")]) + p.save(f"{filename_prefix}_defaults") + + p.set_unit(("gas", "density"), "kg/cm**3") + p.save(f"{filename_prefix}_xunit") + + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_xunit_xlin") + + p.set_unit(("gas", "velocity_x"), "mile/hour") + p.save(f"{filename_prefix}_xunit_xlin_yunit") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_profileplot_set_axis_properties" + test.answer_name = "profileplot_set_axis_properties" + yield test diff --git a/yt/visualization/tests/test_norm_api_set_background_color.py b/yt/visualization/tests/test_norm_api_set_background_color.py new file mode 100644 index 0000000000..a9d92b7247 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_set_background_color.py @@ -0,0 +1,24 @@ +from nose.plugins.attrib import attr + +from yt.testing import ANSWER_TEST_TAG, fake_random_ds +from yt.utilities.answer_testing.framework import GenericImageTest +from yt.visualization.api import SlicePlot + + +@attr(ANSWER_TEST_TAG) +def test_sliceplot_set_background_color(): + # see https://github.com/yt-project/yt/issues/3854 + ds = fake_random_ds(16) + + def create_image(filename_prefix): + field = ("gas", "density") + p = SlicePlot(ds, "z", field, width=1.5) + p.set_background_color(field, color="C0") + p.save(f"{filename_prefix}_log") + p.set_log(("gas", "density"), False) + p.save(f"{filename_prefix}_lin") + + test = GenericImageTest(ds, create_image, 12) + test.prefix = "test_sliceplot_set_background_color" + test.answer_name = "sliceplot_set_background_color" + yield test diff --git a/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py new file mode 100644 index 0000000000..250e0f05b4 --- /dev/null +++ b/yt/visualization/tests/test_norm_api_set_unit_and_zlim.py @@ -0,0 +1,26 @@ +import numpy.testing as npt + +from yt.testing import fake_random_ds +from yt.visualization.api import SlicePlot + + +def test_sliceplot_set_unit_and_zlim_order(): + ds = fake_random_ds(16) + field = ("gas", "density") + + p0 = SlicePlot(ds, "z", field) + p0.set_unit(field, "kg/m**3") + p0.set_zlim(field, zmin=0) + + # reversing order of operations + p1 = SlicePlot(ds, "z", field) + p1.set_zlim(field, zmin=0) + p1.set_unit(field, "kg/m**3") + + p0._setup_plots() + p1._setup_plots() + + im0 = p0.plots[field].image._A + im1 = p1.plots[field].image._A + + npt.assert_allclose(im0, im1) diff --git a/yt/visualization/tests/test_particle_plot.py b/yt/visualization/tests/test_particle_plot.py index b6cf68d64d..d0037e48fd 100644 --- a/yt/visualization/tests/test_particle_plot.py +++ b/yt/visualization/tests/test_particle_plot.py @@ -42,7 +42,7 @@ def setup(): PROJ_ATTR_ARGS["set_log"] = [((("all", "particle_mass"), False), {})] PROJ_ATTR_ARGS["set_zlim"] = [ ((("all", "particle_mass"), 1e39, 1e42), {}), - ((("all", "particle_mass"), 1e39, None), {"dynamic_range": 4}), + ((("all", "particle_mass"),), {"zmin": 1e39, "dynamic_range": 4}), ] PHASE_ATTR_ARGS = { @@ -150,12 +150,17 @@ def formed_star(pfilter, data): ds.add_particle_filter("formed_star") for ax in "xyz": attr_name = "set_log" - for args in PROJ_ATTR_ARGS[attr_name]: - test = PlotWindowAttributeTest( - ds, plot_field, ax, attr_name, args, decimals, "ParticleProjectionPlot" - ) - test_particle_projection_filter.__name__ = test.description - yield test + test = PlotWindowAttributeTest( + ds, + plot_field, + ax, + attr_name, + ((plot_field, False), {}), + decimals, + "ParticleProjectionPlot", + ) + test_particle_projection_filter.__name__ = test.description + yield test @requires_ds(g30, big_data=True) diff --git a/yt/visualization/tests/test_plotwindow.py b/yt/visualization/tests/test_plotwindow.py index c599f347df..b2233c8cd0 100644 --- a/yt/visualization/tests/test_plotwindow.py +++ b/yt/visualization/tests/test_plotwindow.py @@ -5,10 +5,13 @@ from collections import OrderedDict import numpy as np +from matplotlib.colors import LogNorm, Normalize, SymLogNorm from nose.tools import assert_true +from unyt import unyt_array from yt.loaders import load_uniform_grid from yt.testing import ( + assert_allclose_units, assert_array_almost_equal, assert_array_equal, assert_equal, @@ -69,7 +72,7 @@ def setup(): "set_figure_size": [((7.0,), {})], "set_zlim": [ (("density", 1e-25, 1e-23), {}), - (("density", 1e-25, None), {"dynamic_range": 4}), + (("density",), {"zmin": 1e-25, "dynamic_range": 4}), ], "zoom": [((10,), {})], "toggle_right_handed": [((), {})], @@ -423,13 +426,13 @@ def setUp(self): fields_to_plot = fields + [("index", "radius")] if self.ds is None: self.ds = fake_random_ds(16, fields=fields, units=units) - self.slc = ProjectionPlot(self.ds, 0, fields_to_plot) + self.proj = ProjectionPlot(self.ds, 0, fields_to_plot) def tearDown(self): from yt.config import ytcfg del self.ds - del self.slc + del self.proj for key in self.newConfig.keys(): ytcfg.remove(*key) for key, val in self.oldConfig.items(): @@ -438,21 +441,37 @@ def tearDown(self): def test_units(self): from unyt import Unit - assert_equal(self.slc.frb["gas", "density"].units, Unit("mile*lb/yd**3")) - assert_equal(self.slc.frb["gas", "temperature"].units, Unit("cm*K")) - assert_equal(self.slc.frb["gas", "pressure"].units, Unit("dyn/cm")) + assert_equal(self.proj.frb["gas", "density"].units, Unit("mile*lb/yd**3")) + assert_equal(self.proj.frb["gas", "temperature"].units, Unit("cm*K")) + assert_equal(self.proj.frb["gas", "pressure"].units, Unit("dyn/cm")) def test_scale(self): - assert_equal(self.slc._field_transform["gas", "density"].name, "linear") - assert_equal(self.slc._field_transform["gas", "temperature"].name, "symlog") - assert_equal(self.slc._field_transform["gas", "temperature"].func, 100) - assert_equal(self.slc._field_transform["gas", "pressure"].name, "log10") - assert_equal(self.slc._field_transform["index", "radius"].name, "log10") + + assert_equal( + self.proj.plots["gas", "density"].norm_handler.norm_type, Normalize + ) + assert_equal( + self.proj.plots["gas", "temperature"].norm_handler.norm_type, SymLogNorm + ) + assert_allclose_units( + self.proj.plots["gas", "temperature"].norm_handler.linthresh, + unyt_array(100, "K*cm"), + ) + assert_equal(self.proj.plots["gas", "pressure"].norm_handler.norm_type, LogNorm) + assert_equal( + self.proj.plots["index", "radius"].norm_handler.norm_type, SymLogNorm + ) def test_cmap(self): - assert_equal(self.slc._colormap_config["gas", "density"], "plasma") - assert_equal(self.slc._colormap_config["gas", "temperature"], "hot") - assert_equal(self.slc._colormap_config["gas", "pressure"], "viridis") + assert_equal( + self.proj.plots["gas", "density"].colorbar_handler.cmap.name, "plasma" + ) + assert_equal( + self.proj.plots["gas", "temperature"].colorbar_handler.cmap.name, "hot" + ) + assert_equal( + self.proj.plots["gas", "pressure"].colorbar_handler.cmap.name, "viridis" + ) def test_on_off_compare(): @@ -718,7 +737,7 @@ def _neg_density(field, data): ("gas", "negative_density"), ]: plot = SlicePlot(ds, 2, field) - plot.set_log(field, True, linthresh=0.1) + plot.set_log(field, linthresh=0.1) with tempfile.NamedTemporaryFile(suffix="png") as f: plot.save(f.name) diff --git a/yt/visualization/tests/test_set_zlim.py b/yt/visualization/tests/test_set_zlim.py new file mode 100644 index 0000000000..00d41c0211 --- /dev/null +++ b/yt/visualization/tests/test_set_zlim.py @@ -0,0 +1,150 @@ +import numpy as np +import numpy.testing as npt +import pytest + +from yt._maintenance.deprecation import VisibleDeprecationWarning +from yt.testing import fake_amr_ds +from yt.visualization.api import SlicePlot + + +def test_float_vmin_then_set_unit(): + # this test doesn't represent how users should interact with plot containers + # in particular it uses the `_setup_plots()` private method, as a quick way to + # create a plot without having to make it an answer test + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + raw_lims = np.array((cb.vmin, cb.vmax)) + desired_lims = raw_lims.copy() + desired_lims[0] = 1e-2 + + p.set_zlim(field, zmin=desired_lims[0]) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, desired_lims) + + # 1 g/cm**3 == 1000 kg/m**3 + p.set_unit(field, "kg/m**3") + p._setup_plots() + + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, 1000 * desired_lims) + + +def test_set_unit_then_float_vmin(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p.set_unit(field, "kg/m**3") + p.set_zlim(field, zmin=1) + p._setup_plots() + cb = p.plots[field].image.colorbar + assert cb.vmin == 1.0 + + +def test_reset_zlim(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + raw_lims = np.array((cb.vmin, cb.vmax)) + + # set a new zmin value + delta = np.diff(raw_lims)[0] + p.set_zlim(field, zmin=raw_lims[0] + delta / 2) + + # passing "min" should restore default limit + p.set_zlim(field, zmin="min") + p._setup_plots() + + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_array_equal(new_lims, raw_lims) + + +def test_set_dynamic_range_with_vmin(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + zmin = 1e-2 + p.set_zlim(field, zmin=zmin, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (zmin, 2 * zmin)) + + +def test_set_dynamic_range_with_vmax(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + zmax = 1 + p.set_zlim(field, zmax=zmax, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (zmax / 2, zmax)) + + +def test_set_dynamic_range_with_min(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + vmin = cb.vmin + + p.set_zlim(field, zmin="min", dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (vmin, 2 * vmin)) + + +def test_set_dynamic_range_with_None(): + field = ("gas", "density") + ds = fake_amr_ds(fields=[field], units=["g/cm**3"]) + + p = SlicePlot(ds, "x", field) + p.set_buff_size(16) + + p._setup_plots() + cb = p.plots[field].image.colorbar + vmin = cb.vmin + + with pytest.raises( + VisibleDeprecationWarning, match="Passing `zmin=None` explicitly is deprecated" + ): + p.set_zlim(field, zmin=None, dynamic_range=2) + + p._setup_plots() + cb = p.plots[field].image.colorbar + new_lims = np.array((cb.vmin, cb.vmax)) + npt.assert_almost_equal(new_lims, (vmin, 2 * vmin))