Skip to content

Commit

Permalink
ENH: add support for symlog colorbars with arbitrary bases (requires …
Browse files Browse the repository at this point in the history
…MPL>=3.5)
  • Loading branch information
neutrinoceros committed May 28, 2023
1 parent 6c3bc26 commit e8bd650
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 96 deletions.
82 changes: 22 additions & 60 deletions yt/visualization/_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import matplotlib as mpl
import numpy as np
from matplotlib.ticker import SymmetricalLogLocator
from more_itertools import always_iterable
from packaging.version import Version

Expand Down Expand Up @@ -229,66 +230,6 @@ def _swap_arg_pair_order(*args):
return tuple(new_args)


def get_log_minorticks(vmin: float, vmax: float) -> np.ndarray:
"""calculate positions of linear minorticks on a log colorbar
Parameters
----------
vmin : float
the minimum value in the colorbar
vmax : float
the maximum value in the colorbar
"""
expA = np.floor(np.log10(vmin))
expB = np.floor(np.log10(vmax))
cofA = np.ceil(vmin / 10**expA).astype("int64")
cofB = np.floor(vmax / 10**expB).astype("int64")
lmticks = np.empty(0)
while cofA * 10**expA <= cofB * 10**expB:
if expA < expB:
lmticks = np.hstack((lmticks, np.linspace(cofA, 9, 10 - cofA) * 10**expA))
cofA = 1
expA += 1
else:
lmticks = np.hstack(
(lmticks, np.linspace(cofA, cofB, cofB - cofA + 1) * 10**expA)
)
expA += 1
return np.array(lmticks)


def get_symlog_minorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray:
"""calculate positions of linear minorticks on a symmetric log colorbar
Parameters
----------
linthresh : float
the threshold for the linear region
vmin : float
the minimum value in the colorbar
vmax : float
the maximum value in the colorbar
"""
if vmin > 0:
return get_log_minorticks(vmin, vmax)
elif vmax < 0 and vmin < 0:
return -get_log_minorticks(-vmax, -vmin)
elif vmin == 0:
return np.hstack((0, get_log_minorticks(linthresh, vmax)))
elif vmax == 0:
return np.hstack((-get_log_minorticks(linthresh, -vmin)[::-1], 0))
else:
return np.hstack(
(
-get_log_minorticks(linthresh, -vmin)[::-1],
0,
get_log_minorticks(linthresh, vmax),
)
)


def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray:
"""calculate positions of major ticks on a log colorbar
Expand All @@ -302,6 +243,9 @@ def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndar
the maximum value in the colorbar
"""
if MPL_VERSION >= Version("3.5"):
raise RuntimeError("get_symlog_majorticks is not needed with matplotlib>=3.5")

if vmin >= 0.0:
yticks = [vmin] + list(
10
Expand Down Expand Up @@ -352,6 +296,24 @@ def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndar
return np.array(yticks)


class _MPL38_SymmetricalLogLocator(SymmetricalLogLocator):
# Backporting behaviour from matplotlib 3.8 (in development at the time of writing)
# see https://github.com/matplotlib/matplotlib/pull/25970

def tick_values(self, vmin, vmax):
if MPL_VERSION >= Version("3.8"):
return super().tick_values(vmin, vmax)

linthresh = self._linthresh
if vmax < vmin:
vmin, vmax = vmax, vmin
if -linthresh <= vmin < vmax <= linthresh:
# only the linear range is present
return sorted({vmin, 0, vmax})

return super().tick_values(vmin, vmax)


def get_default_from_config(data_source, *, field, keys, defaults):
_keys = list(always_iterable(keys))
_defaults = list(always_iterable(defaults))
Expand Down
100 changes: 64 additions & 36 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
MPL_VERSION,
get_canvas,
get_symlog_majorticks,
get_symlog_minorticks,
validate_image_name,
)

if MPL_VERSION >= Version("3.8"):
from matplotlib.ticker import SymmetricalLogLocator
else:
from ._commons import _MPL38_SymmetricalLogLocator as SymmetricalLogLocator

if TYPE_CHECKING:
from matplotlib.axis import Axis
from matplotlib.figure import Figure
Expand Down Expand Up @@ -311,6 +315,13 @@ def _init_image(self, data, extent, aspect):
self._set_axes(norm)

def _set_axes(self, norm: Normalize) -> None:
if MPL_VERSION >= Version("3.5"):
self._set_axes_mpl_ge35()
else:
self._set_axes_mpl_lt35(norm)

def _set_axes_mpl_lt35(self, norm: Normalize) -> None:
# bug-for-bug backward-compatibility for matplotlib older than 3.5
if isinstance(norm, SymLogNorm):
formatter = LogFormatterMathtext(linthresh=norm.linthresh)
self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
Expand All @@ -328,48 +339,65 @@ def _set_axes(self, norm: Normalize) -> None:
if type(norm) not in (LogNorm, SymLogNorm):
try:
self.cb.ax.ticklabel_format(**fmt_kwargs)
except AttributeError as exc:
if MPL_VERSION < Version("3.5.0"):
warnings.warn(
"Failed to format colorbar ticks. "
"This is expected when using the set_norm method "
"with some matplotlib classes (e.g. TwoSlopeNorm) "
"with matplotlib versions older than 3.5\n"
"Please try upgrading matplotlib to a more recent version. "
"If the problem persists, please file a report to "
"https://github.com/yt-project/yt/issues/new",
stacklevel=2,
)
else:
raise exc
if self.colorbar_handler.draw_minorticks:
if isinstance(norm, SymLogNorm):
if MPL_VERSION < Version("3.5.0b"):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
pass
else:
flinthresh = 10 ** np.floor(np.log10(norm.linthresh))
absmax = np.abs((norm.vmin, norm.vmax)).max()
if (absmax - flinthresh) / absmax < 0.1:
flinthresh /= 10
mticks = get_symlog_minorticks(flinthresh, norm.vmin, norm.vmax)
if MPL_VERSION < Version("3.5.0b"):
# https://github.com/matplotlib/matplotlib/issues/21258
mticks = self.image.norm(mticks)
self.cax.yaxis.set_ticks(mticks, minor=True)

elif isinstance(norm, LogNorm):
self.cax.minorticks_on()
self.cax.xaxis.set_visible(False)
except AttributeError:
warnings.warn(
"Failed to format colorbar ticks. "
"This is expected when using the set_norm method "
"with some matplotlib classes (e.g. TwoSlopeNorm) "
"with matplotlib versions older than 3.5\n"
"Please try upgrading matplotlib to a more recent version. "
"If the problem persists, please file a report to "
"https://github.com/yt-project/yt/issues/new",
stacklevel=2,
)

else:
if self.colorbar_handler.draw_minorticks:
if not isinstance(norm, SymLogNorm):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
self.cax.minorticks_on()
else:
self.cax.minorticks_off()

self.image.axes.set_facecolor(self.colorbar_handler.background_color)

def _set_axes_mpl_ge35(self) -> None:
fmt_kwargs = {"style": "scientific", "scilimits": (-2, 3), "useMathText": True}
self.image.axes.ticklabel_format(**fmt_kwargs)

self.cax.tick_params(which="both", direction="in")
self.cb = self.figure.colorbar(self.image, self.cax)

if self.cb.orientation == "vertical":
cb_axis = self.cb.ax.yaxis
else:
cb_axis = self.cb.ax.xaxis

cb_scale = cb_axis.get_scale()
if cb_scale == "symlog":
trf = cb_axis.get_transform()
cb_axis.set_major_locator(SymmetricalLogLocator(trf))
cb_axis.set_major_formatter(
LogFormatterMathtext(linthresh=trf.linthresh, base=trf.base)
)

if cb_scale not in ("log", "symlog"):
self.cb.ax.ticklabel_format(**fmt_kwargs)

if self.colorbar_handler.draw_minorticks and cb_scale == "symlog":
# no minor ticks are drawn by default in symlog, as of matplotlib 3.7.1
# see https://github.com/matplotlib/matplotlib/issues/25994
trf = cb_axis.get_transform()
if float(trf.base).is_integer():
locator = SymmetricalLogLocator(trf, subs=np.arange(1, trf.base))
cb_axis.set_minor_locator(locator)
elif self.colorbar_handler.draw_minorticks:
self.cb.minorticks_on()
else:
self.cb.minorticks_off()

self.image.axes.set_facecolor(self.colorbar_handler.background_color)

def _validate_axes_extent(self, extent, transform):
# if the axes are cartopy GeoAxes, this checks that the axes extent
# is properly set.
Expand Down
20 changes: 20 additions & 0 deletions yt/visualization/tests/test_image_comp_2D_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import numpy.testing as npt
import pytest
from matplotlib.colors import SymLogNorm
from packaging.version import Version

from yt.data_objects.profiles import create_profile
Expand Down Expand Up @@ -91,6 +92,25 @@ def test_sliceplot_custom_norm():
return p.plots[field].figure


@pytest.mark.skipif(
MPL_VERSION < Version("3.5"),
reason=f"Correct behaviour requires MPL 3.5 we have {MPL_VERSION}",
)
@pytest.mark.mpl_image_compare
def test_sliceplot_custom_norm_symlog2():
ds = fake_random_ds(16)
add_noise_fields(ds)
field = "noise3"
p = SlicePlot(ds, "z", field)

# using integer base !=10 and >2 to exercise special case
# for colorbar minor ticks
p.set_norm(field, norm=SymLogNorm(linthresh=0.1, base=5))

p.render()
return p.plots[field].figure


@pytest.mark.skipif(
MPL_VERSION >= Version("3.5"),
reason=f"Testing a warning that should only happen with MPL < 3.5, we have {MPL_VERSION}",
Expand Down

0 comments on commit e8bd650

Please sign in to comment.