Skip to content

Commit

Permalink
Address inf_as_na pandas deprecation (#3424)
Browse files Browse the repository at this point in the history
* Address inf_as_na pandas deprecation

* Add -np.inf, add import

* flake8

* Make copy

* Use mask instead of replace
  • Loading branch information
mroeschke authored Aug 19, 2023
1 parent aebf7d8 commit 2386036
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
14 changes: 7 additions & 7 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,13 +1121,13 @@ def comp_data(self):
parts = []
grouped = self.plot_data[var].groupby(self.converters[var], sort=False)
for converter, orig in grouped:
with pd.option_context('mode.use_inf_as_na', True):
orig = orig.dropna()
if var in self.var_levels:
# TODO this should happen in some centralized location
# it is similar to GH2419, but more complicated because
# supporting `order` in categorical plots is tricky
orig = orig[orig.isin(self.var_levels[var])]
orig = orig.mask(orig.isin([np.inf, -np.inf]), np.nan)
orig = orig.dropna()
if var in self.var_levels:
# TODO this should happen in some centralized location
# it is similar to GH2419, but more complicated because
# supporting `order` in categorical plots is tricky
orig = orig[orig.isin(self.var_levels[var])]
comp = pd.to_numeric(converter.convert_units(orig)).astype(float)
if converter.get_scale() == "log":
comp = np.log10(comp)
Expand Down
34 changes: 19 additions & 15 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from matplotlib.axes import Axes
from matplotlib.artist import Artist
from matplotlib.figure import Figure
import numpy as np
from PIL import Image

from seaborn._marks.base import Mark
Expand Down Expand Up @@ -1587,21 +1588,24 @@ def split_generator(keep_na=False) -> Generator:

axes_df = self._filter_subplot_data(df, view)

with pd.option_context("mode.use_inf_as_na", True):
if keep_na:
# The simpler thing to do would be x.dropna().reindex(x.index).
# But that doesn't work with the way that the subset iteration
# is written below, which assumes data for grouping vars.
# Matplotlib (usually?) masks nan data, so this should "work".
# Downstream code can also drop these rows, at some speed cost.
present = axes_df.notna().all(axis=1)
nulled = {}
for axis in "xy":
if axis in axes_df:
nulled[axis] = axes_df[axis].where(present)
axes_df = axes_df.assign(**nulled)
else:
axes_df = axes_df.dropna()
axes_df_inf_as_nan = axes_df.copy()
axes_df_inf_as_nan = axes_df_inf_as_nan.mask(
axes_df_inf_as_nan.isin([np.inf, -np.inf]), np.nan
)
if keep_na:
# The simpler thing to do would be x.dropna().reindex(x.index).
# But that doesn't work with the way that the subset iteration
# is written below, which assumes data for grouping vars.
# Matplotlib (usually?) masks nan data, so this should "work".
# Downstream code can also drop these rows, at some speed cost.
present = axes_df_inf_as_nan.notna().all(axis=1)
nulled = {}
for axis in "xy":
if axis in axes_df:
nulled[axis] = axes_df[axis].where(present)
axes_df = axes_df_inf_as_nan.assign(**nulled)
else:
axes_df = axes_df_inf_as_nan.dropna()

subplot_keys = {}
for dim in ["col", "row"]:
Expand Down

0 comments on commit 2386036

Please sign in to comment.