Skip to content

Commit

Permalink
ENH: add support for symlog colorbars with arbitrary bases (requires …
Browse files Browse the repository at this point in the history
…MPL>=3.5)
  • Loading branch information
neutrinoceros committed May 21, 2023
1 parent 083ff1d commit 6f5084e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 20 deletions.
57 changes: 56 additions & 1 deletion yt/visualization/_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import warnings
from functools import wraps
from importlib.metadata import version
from typing import TYPE_CHECKING, Optional, Type, TypeVar
from itertools import chain
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar

import matplotlib as mpl
import numpy as np
from matplotlib.ticker import SymmetricalLogLocator
from more_itertools import always_iterable
from packaging.version import Version

Expand Down Expand Up @@ -358,6 +360,59 @@ def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndar
return np.array(yticks)


class SymmetricalLogIntMinorLocator(SymmetricalLogLocator):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if not float(self._base).is_integer():
raise ValueError(
"SymmetricalLogIntMinorLocator requires an integer log base, "
f"got {self._base=}"
)
self._base = int(self._base)

def tick_values(self, vmin: float, vmax: float) -> Sequence[float]:
base = self._base
linthresh = self._linthresh

if vmax < vmin:
vmin, vmax = vmax, vmin

def intlog_pos_range(vmin: float, vmax: float, base: int):
expA = np.floor(np.log(vmin) / np.log(base))
expB = np.floor(np.log(vmax) / np.log(base))
cofA = int(np.ceil(vmin / base**expA))
cofB = int(np.floor(vmax / base**expB))
lmticks = []
while cofA * base**expA <= cofB * base**expB:
if expA < expB:
lmticks.append(
np.linspace(cofA, base - 1, base - cofA) * base**expA
)
cofA = 1
else:
lmticks.append(
np.linspace(cofA, cofB, cofB - cofA + 1) * base**expA
)
expA += 1
return np.array(list(chain.from_iterable(lmticks)))

if vmin > 0:
return intlog_pos_range(vmin, vmax, base)
elif vmax < 0:
return -intlog_pos_range(-vmax, -vmin, base)
elif vmin == 0:
return intlog_pos_range(linthresh, vmax, base)
elif vmax == 0:
return -intlog_pos_range(linthresh, -vmin, base)[::-1]
else:
return np.hstack(
(
-intlog_pos_range(linthresh, -vmin, base)[::-1],
intlog_pos_range(linthresh, vmax, base),
)
)


def get_default_from_config(data_source, *, field, keys, defaults):
_keys = list(always_iterable(keys))
_defaults = list(always_iterable(defaults))
Expand Down
55 changes: 36 additions & 19 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from ._commons import (
MPL_VERSION,
SymmetricalLogIntMinorLocator,
get_canvas,
get_symlog_majorticks,
get_symlog_minorticks,
validate_image_name,
)

Expand Down Expand Up @@ -312,15 +312,26 @@ def _init_image(self, data, extent, aspect):

def _set_axes(self, norm: Normalize) -> None:
if isinstance(norm, SymLogNorm):
formatter = LogFormatterMathtext(linthresh=norm.linthresh)
self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
if MPL_VERSION >= Version("3.5"):
logbase = norm._trf.base
else:
# bug-for-bug backward-compatibility for older versions of matplotlib
# at the risk of being incorrect in case the norm was user-defined
logbase = 10

formatter = LogFormatterMathtext(linthresh=norm.linthresh, base=logbase)
else:
formatter = None

self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)

if isinstance(norm, SymLogNorm) and MPL_VERSION < Version("3.5"):
self.cb.set_ticks(
get_symlog_majorticks(
linthresh=norm.linthresh, vmin=norm.vmin, vmax=norm.vmax
)
)
else:
self.cb = self.figure.colorbar(self.image, self.cax)

self.cax.tick_params(which="both", direction="in")

fmt_kwargs = {"style": "scientific", "scilimits": (-2, 3), "useMathText": True}
Expand All @@ -342,22 +353,28 @@ def _set_axes(self, norm: Normalize) -> None:
)
else:
raise exc
if self.colorbar_handler.draw_minorticks:
if isinstance(norm, SymLogNorm):
if MPL_VERSION >= Version("3.5.0b"):
self.cb.set_ticks(
get_symlog_minorticks(norm.linthresh, norm.vmin, norm.vmax),
minor=True,
)
else:
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
# and https://github.com/matplotlib/matplotlib/issues/21258
pass

if not self.colorbar_handler.draw_minorticks:
self.cax.minorticks_off()
elif MPL_VERSION >= Version("3.5"):
if self.cb.orientation == "vertical":
axis = self.cb.ax.yaxis
else:
self.cax.minorticks_on()
axis = self.cb.ax.xaxis
s = axis._scale
if axis.get_scale() == "symlog" and float(s._transform.base).is_integer():
axis.set_minor_locator(
SymmetricalLogIntMinorLocator(s._transform, s.subs)
)
else:
self.cb.minorticks_on()
elif isinstance(norm, SymLogNorm):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
# and https://github.com/matplotlib/matplotlib/issues/21258
pass
else:
self.cax.minorticks_off()
self.cb.minorticks_on()

self.image.axes.set_facecolor(self.colorbar_handler.background_color)

Expand Down

0 comments on commit 6f5084e

Please sign in to comment.