Skip to content

Commit

Permalink
Merge pull request #4089 from neutrinoceros/hotfix_depr_colormap_api
Browse files Browse the repository at this point in the history
BUG: avoid using deprecated matplotlib.cm API
  • Loading branch information
neutrinoceros authored Aug 22, 2022
2 parents f5578fc + 638d074 commit 8beb2d2
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 22 deletions.
4 changes: 2 additions & 2 deletions doc/source/visualizing/colormaps/cmap_images.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import matplotlib.cm as cm
import matplotlib as mpl

import yt

Expand All @@ -8,7 +8,7 @@
# Create projections using each colormap available.
p = yt.ProjectionPlot(ds, "z", "density", weight_field="density", width=0.4)

for cmap in cm.datad:
for cmap in mpl.colormaps:
if cmap.startswith("idl"):
continue
p.set_cmap(field="density", cmap=cmap)
Expand Down
39 changes: 29 additions & 10 deletions yt/visualization/color_maps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from copy import deepcopy
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import cmyt # noqa: F401
import matplotlib as mpl
import numpy as np
from matplotlib import cm as mcm, colors as cc
from matplotlib.colors import Colormap, LinearSegmentedColormap
from packaging.version import Version

from yt.funcs import get_brewer_cmap
Expand All @@ -12,6 +13,24 @@
from . import _colormap_data as _cm
from ._commons import MPL_VERSION


# wrap matplotlib.cm API, use non-deprecated API when available
def _get_cmap(name: str) -> Colormap:
if MPL_VERSION >= Version("3.5"):
return mpl.colormaps[name]
else:
# deprecated API
return mpl.cm.get_cmap(name)


def _register_cmap(cmap: Colormap, *, name: Optional[str] = None) -> None:
if MPL_VERSION >= Version("3.5"):
mpl.colormaps.register(cmap, name=name)
else:
# deprecated API
mpl.cm.register_cmap(name=name, cmap=cmap)


yt_colormaps = {}


Expand All @@ -20,8 +39,8 @@ def add_colormap(name, cdict):
Adds a colormap to the colormaps available in yt for this session
"""
# Note: this function modifies the global variable 'yt_colormaps'
yt_colormaps[name] = cc.LinearSegmentedColormap(name, cdict, 256)
mcm.register_cmap(name, yt_colormaps[name])
yt_colormaps[name] = LinearSegmentedColormap(name, cdict, 256)
_register_cmap(yt_colormaps[name], name=name)


# YTEP-0040 backward compatibility layer
Expand Down Expand Up @@ -51,13 +70,13 @@ def register_yt_colormaps_from_cmyt():

for hist_name, alias in _HISTORICAL_ALIASES.items():
if MPL_VERSION >= Version("3.4.0"):
cmap = mcm.get_cmap(alias).copy()
cmap = _get_cmap(alias).copy()
else:
cmap = deepcopy(mcm.get_cmap(alias))
cmap = deepcopy(_get_cmap(alias))
cmap.name = hist_name
try:
mcm.register_cmap(cmap=cmap)
mcm.register_cmap(cmap=mcm.get_cmap(hist_name).reversed())
_register_cmap(cmap=cmap)
_register_cmap(cmap=_get_cmap(hist_name).reversed())
except ValueError:
# Matplotlib 3.4.0 hard-forbids name collisions, but more recent versions
# will emit a warning instead, so we emulate this behaviour regardless.
Expand Down Expand Up @@ -95,7 +114,7 @@ def get_colormap_lut(cmap_id: Union[Tuple[str, str], str]):
if isinstance(cmap_id, tuple) and len(cmap_id) == 2:
cmap = get_brewer_cmap(cmap_id)
elif isinstance(cmap_id, str):
cmap = mcm.get_cmap(cmap_id)
cmap = _get_cmap(cmap_id)
else:
raise TypeError(
"Expected a string or a 2-tuple of strings as a colormap id. "
Expand Down Expand Up @@ -182,7 +201,7 @@ def show_colormaps(subset="all", filename=None):
for i, m in enumerate(maps):
plt.subplot(1, l, i + 1)
plt.axis("off")
plt.imshow(a, aspect="auto", cmap=mcm.get_cmap(m), origin="lower")
plt.imshow(a, aspect="auto", cmap=_get_cmap(m), origin="lower")
plt.title(m, rotation=90, fontsize=10, verticalalignment="bottom")
if filename is not None:
plt.savefig(filename, dpi=100, facecolor="gray")
Expand Down
4 changes: 2 additions & 2 deletions yt/visualization/eps_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pyx
from matplotlib import cm, pyplot as plt
from matplotlib import pyplot as plt

from yt.config import ytcfg
from yt.units.unit_object import Unit # type: ignore
Expand Down Expand Up @@ -744,7 +744,7 @@ def colorbar(

# Convert the colormap into a string
x = np.linspace(1, 0, 256)
cm_string = cm.get_cmap(name)(x, bytes=True)[:, 0:3].tobytes()
cm_string = plt.get_cmap(name)(x, bytes=True)[:, 0:3].tobytes()

cmap_im = pyx.bitmap.image(imsize[0], imsize[1], "RGB", cm_string)
if orientation == "top" or orientation == "bottom":
Expand Down
5 changes: 3 additions & 2 deletions yt/visualization/plot_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from matplotlib.cm import get_cmap
from matplotlib.font_manager import FontProperties
from more_itertools.more import always_iterable

Expand Down Expand Up @@ -916,9 +915,11 @@ def set_background_color(self, field, color=None):
"""
if color is None:
from yt.visualization.color_maps import _get_cmap

cmap = self._colormap_config[field]
if isinstance(cmap, str):
cmap = get_cmap(cmap)
cmap = _get_cmap(cmap)
color = cmap(0)
self._background_color[field] = color
return self
Expand Down
3 changes: 2 additions & 1 deletion yt/visualization/tests/test_splat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import yt
from yt.testing import assert_equal
from yt.utilities.lib.api import add_rgba_points_to_image # type: ignore
from yt.visualization.color_maps import _get_cmap


def setup():
Expand All @@ -30,7 +31,7 @@ def test_splat():
xs = prng.random_sample(Np)
ys = prng.random_sample(Np)

cbx = yt.visualization.color_maps.mcm.RdBu
cbx = _get_cmap("RdBu")
cs = cbx(prng.random_sample(Np))
add_rgba_points_to_image(image, xs, ys, cs)

Expand Down
5 changes: 3 additions & 2 deletions yt/visualization/volume_rendering/image_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ def plot_channel(
specified.
"""
from matplotlib import pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import LogNorm

from yt.visualization.color_maps import _get_cmap

Nvec = image.shape[0]
image[np.isnan(image)] = 0.0
ma = image[image > 0.0].max()
Expand All @@ -94,7 +95,7 @@ def plot_channel(
fig.subplots_adjust(
left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0
)
mycm = get_cmap(cmap)
mycm = _get_cmap(cmap)
if log:
ax.imshow(image, cmap=mycm, norm=mynorm, interpolation="nearest")
else:
Expand Down
9 changes: 6 additions & 3 deletions yt/visualization/volume_rendering/transfer_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from matplotlib.cm import get_cmap
from more_itertools import always_iterable

from yt.funcs import mylog
Expand Down Expand Up @@ -728,12 +727,14 @@ def sample_colormap(self, v, w, alpha=None, colormap="gist_stern", col_bounds=No
>>> tf = ColorTransferFunction((-10.0, -5.0))
>>> tf.sample_colormap(-7.0, 0.01, colormap="cmyt.arbre")
"""
from yt.visualization.color_maps import _get_cmap

v = np.float64(v)
if col_bounds is None:
rel = (v - self.x_bounds[0]) / (self.x_bounds[1] - self.x_bounds[0])
else:
rel = (v - col_bounds[0]) / (col_bounds[1] - col_bounds[0])
cmap = get_cmap(colormap)
cmap = _get_cmap(colormap)
r, g, b, a = cmap(rel)
if alpha is None:
alpha = a
Expand Down Expand Up @@ -780,6 +781,8 @@ def map_to_colormap(
... -6.0, -5.0, scale=10.0, colormap="cmyt.arbre", scale_func=linramp
... )
"""
from yt.visualization.color_maps import _get_cmap

mi = np.float64(mi)
ma = np.float64(ma)
rel0 = int(
Expand All @@ -791,7 +794,7 @@ def map_to_colormap(
rel0 = max(rel0, 0)
rel1 = min(rel1, self.nbins - 1) + 1
tomap = np.linspace(0.0, 1.0, num=rel1 - rel0)
cmap = get_cmap(colormap)
cmap = _get_cmap(colormap)
cc = cmap(tomap)
if scale_func is None:
scale_mult = 1.0
Expand Down

0 comments on commit 8beb2d2

Please sign in to comment.