Skip to content

Commit

Permalink
Single matplotlib import (pydata#5794)
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan authored Oct 24, 2021
1 parent b791558 commit ea28861
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 42 deletions.
9 changes: 9 additions & 0 deletions asv_bench/benchmarks/import_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class ImportXarray:
def setup(self, *args, **kwargs):
def import_xr():
import xarray # noqa: F401

self._import_xr = import_xr

def time_import_xarray(self):
self._import_xr()
12 changes: 4 additions & 8 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_process_cmap_cbar_kwargs,
get_axis,
label_from_attrs,
plt,
)

# copied from seaborn
Expand Down Expand Up @@ -134,8 +135,7 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None)

# copied from seaborn
def _parse_size(data, norm):

import matplotlib as mpl
mpl = plt.matplotlib

if data is None:
return None
Expand Down Expand Up @@ -544,8 +544,6 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
"""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")

Expand All @@ -560,7 +558,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs):

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params["norm"] = plt.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

Expand All @@ -576,8 +574,6 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
"""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for streamplot plots.")

Expand Down Expand Up @@ -613,7 +609,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params["norm"] = plt.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

Expand Down
10 changes: 2 additions & 8 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
_get_nice_quiver_magnitude,
_infer_xy_labels,
_process_cmap_cbar_kwargs,
import_matplotlib_pyplot,
label_from_attrs,
plt,
)

# Overrides axes.labelsize, xtick.major.size, ytick.major.size
Expand Down Expand Up @@ -116,8 +116,6 @@ def __init__(
"""

plt = import_matplotlib_pyplot()

# Handle corner case of nonunique coordinates
rep_col = col is not None and not data[col].to_index().is_unique
rep_row = row is not None and not data[row].to_index().is_unique
Expand Down Expand Up @@ -519,10 +517,8 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar
self: FacetGrid object
"""
import matplotlib as mpl

if size is None:
size = mpl.rcParams["axes.labelsize"]
size = plt.rcParams["axes.labelsize"]

nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)

Expand Down Expand Up @@ -619,8 +615,6 @@ def map(self, func, *args, **kwargs):
self : FacetGrid object
"""
plt = import_matplotlib_pyplot()

for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
if namedict is not None:
data = self.data.loc[namedict]
Expand Down
8 changes: 1 addition & 7 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
_resolve_intervals_2dplot,
_update_axes,
get_axis,
import_matplotlib_pyplot,
label_from_attrs,
legend_elements,
plt,
)

# copied from seaborn
Expand Down Expand Up @@ -83,8 +83,6 @@ def _parse_size(data, norm, width):
If the data is categorical, normalize it to numbers.
"""
plt = import_matplotlib_pyplot()

if data is None:
return None

Expand Down Expand Up @@ -682,8 +680,6 @@ def scatter(
**kwargs : optional
Additional keyword arguments to matplotlib
"""
plt = import_matplotlib_pyplot()

# Handle facetgrids first
if row or col:
allargs = locals().copy()
Expand Down Expand Up @@ -1111,8 +1107,6 @@ def newplotfunc(
allargs["plotfunc"] = globals()[plotfunc.__name__]
return _easy_facetgrid(darray, kind="dataarray", **allargs)

plt = import_matplotlib_pyplot()

if (
plotfunc.__name__ == "surface"
and not kwargs.get("_is_facetgrid", False)
Expand Down
34 changes: 15 additions & 19 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def import_matplotlib_pyplot():
return plt


try:
plt = import_matplotlib_pyplot()
except ImportError:
plt = None


def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
Expand All @@ -64,7 +70,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
"""
Build a discrete colormap and normalization of the data.
"""
import matplotlib as mpl
mpl = plt.matplotlib

if len(levels) == 1:
levels = [levels[0], levels[0]]
Expand Down Expand Up @@ -115,8 +121,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled):


def _color_palette(cmap, n_colors):
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
ListedColormap = plt.matplotlib.colors.ListedColormap

colors_i = np.linspace(0, 1.0, n_colors)
if isinstance(cmap, (list, tuple)):
Expand Down Expand Up @@ -177,7 +182,7 @@ def _determine_cmap_params(
cmap_params : dict
Use depends on the type of the plotting function
"""
import matplotlib as mpl
mpl = plt.matplotlib

if isinstance(levels, Iterable):
levels = sorted(levels)
Expand Down Expand Up @@ -285,13 +290,13 @@ def _determine_cmap_params(
levels = np.asarray([(vmin + vmax) / 2])
else:
# N in MaxNLocator refers to bins, not ticks
ticker = mpl.ticker.MaxNLocator(levels - 1)
ticker = plt.MaxNLocator(levels - 1)
levels = ticker.tick_values(vmin, vmax)
vmin, vmax = levels[0], levels[-1]

# GH3734
if vmin == vmax:
vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)
vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax)

if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)
Expand Down Expand Up @@ -421,10 +426,7 @@ def _assert_valid_xy(darray, xy, name):


def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError:
if plt is None:
raise ImportError("matplotlib is required for plot.utils.get_axis")

if figsize is not None:
Expand All @@ -437,7 +439,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
if ax is not None:
raise ValueError("cannot provide both `size` and `ax` arguments")
if aspect is None:
width, height = mpl.rcParams["figure.figsize"]
width, height = plt.rcParams["figure.figsize"]
aspect = width / height
figsize = (size * aspect, size)
_, ax = plt.subplots(figsize=figsize)
Expand All @@ -454,9 +456,6 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):


def _maybe_gca(**kwargs):

import matplotlib.pyplot as plt

# can call gcf unconditionally: either it exists or would be created by plt.axes
f = plt.gcf()

Expand Down Expand Up @@ -912,9 +911,7 @@ def _process_cmap_cbar_kwargs(


def _get_nice_quiver_magnitude(u, v):
import matplotlib as mpl

ticker = mpl.ticker.MaxNLocator(3)
ticker = plt.MaxNLocator(3)
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude
Expand Down Expand Up @@ -989,7 +986,7 @@ def legend_elements(
"""
import warnings

import matplotlib as mpl
mpl = plt.matplotlib

mlines = mpl.lines

Expand Down Expand Up @@ -1126,7 +1123,6 @@ def _legend_add_subtitle(handles, labels, text, func):

def _adjust_legend_subtitles(legend):
"""Make invisible-handle "subtitles" entries look more like titles."""
plt = import_matplotlib_pyplot()

# Legend title not in rcParams until 3.0
font_size = plt.rcParams.get("legend.title_fontsize", None)
Expand Down

0 comments on commit ea28861

Please sign in to comment.