Add dataarray scatter (#6778)
* allow adding any number of extra coords

* Explain how ds will become darray

* Update dataset_plot.py

* Update dataset_plot.py

* use coords for coords

* Explain goal of moving ds plots to da

* Update dataset_plot.py

* Update dataset_plot.py

* Update dataset_plot.py

* handle non-existent coords

* Update dataset_plot.py

* Look through the kwargs to find extra coords

* output of legend labels has changed

* pop plt, comment out error test

* Update dataset_plot.py

* Update facetgrid.py

* move some funcs to utils

* add the funcs to the moved place

* various bugfixes

* use coords to check if valid
* only normalize sizes, hue is not necessary.
* Use same scatter parameter order as the dataset version.
* Fix tests assuming a list of PathCollections is returned.

* improve ds to da wrapper

* Filter kwargs

* normalize args to be able to filter the correct args

* Update plot.py

* Update plot.py

* Update dataset_plot.py

* Some fixes to string colorbar

* Update plot.py

* Check if hue is str

* Fix some failing tests

* Update dataset_plot.py

* Add more relevant params higher up

* use hue in facetgrid, normalize data

* Update plot.py

* Move parts of scatter to a decorator

* Update plot.py

* Update plot.py

* get scatter to work with decorator

* use correct name

* Add a Normalize class

For categoricals to work most of the time a normalization to numerics has to be done. Once shown on the plot it has to be reformatted however with a lookup function

* skip use of Literal

* remove test code

* fix lint errors

* more linting fixes

* doctests fixing

* Update utils.py

* Update plot.py

* Update utils.py

* Update plot.py

* Update facetgrid.py

* revert some old ideas

* Update utils.py

* Update plot.py

* trim unused code

* use to_numpy instead

* more pint compats

* work on facetgrid legends

* facetgrid colorbar tweaks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Categoricals starts on 1 and is bounded 0,2

This makes plt.colorbar return ticks in the center of the color

* Handle None in Normalize

* Fix labels

* Update plot.py

* determine guide

* fix plt

* Update facetgrid.py

* Don't be able to plot empty legends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try out linecollection so lines behaves similar to scatter

* linecollections half working

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* A few variations of linecollection

* linecollection can behave as scatter, with hue and size, But which part of the array will be considered a line and how do you filter for that?

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* line to utils

* line plot changes

* reshape to get hues working

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* line edits legend not nice on line plots yet

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tutorial.py

* doc changes, tuple to dict

* nice line plots and working legend

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* comment out some variants

* some cleanup

* Guess some dims if they weren't defined

* None is supposed to pass as well

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make precommit happy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add hist, step

* handle step using repeat, remove pint errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* handle pint

* fix some tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use isel instead to be independent of categoricals or not

* allow multiple primitives and filter duplicates

* Update test_plot.py

* Copy data inside instead at init.

* Histograms have counted values along y, switch around x and y labels.

* output as numpy array

* histogram outputs primitive only

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update facetgrid.py

* use add_labels inputs, explicit indexes now handles attrs

* colorbar in correct position

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Avoid always stacking

To avoid adding unnecessary NaNs.

* linecollection fixes

TODO is to make sure the values are plotted along the same axis.

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add dataarray scatter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* Update plot.py

* out of scope stuff

* Update test_plot.py

* Update plot.py

* fix some tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update whats-new.rst

* Update utils.py

* Update xarray/plot/facetgrid.py

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Update plot.py

* typo

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Update xarray/plot/utils.py

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* some typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update facetgrid.py

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert name to string in label_from_attrs

* Update whats-new.rst

* Add typing to some interval funcs

* undo parse_size edits, not necessary

* ax not needed

* Add some typing

* Update utils.py

* Cleaner retrieval of add_labels and

* type hints

* Fix facetgrid and normal plot not matching

* Update facetgrid.py

* Update plot.py

* Add typing to dataset funcs + some fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataset_plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add type hints to plot1d

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update facetgrid.py

* Update facetgrid.py

* Update facetgrid.py

* remove sharex for 3d plots, not supported. Add set_lims so all data in plots are shown

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update facetgrid.py

* Update facetgrid.py

* fix typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Self should be any

* more fixes to typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update facetgrid.py

* fix some mypy errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plot.py

* Update whats-new.rst

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <axbanihirwe@ualr.edu>
4 people authored Oct 7, 2022
1 parent 50ea159 commit 8dac64b
Showing 6 changed files with 1,176 additions and 590 deletions.
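
The "Add a Normalize class" entry in the commit message says that categorical values are normalized to numerics for plotting and mapped back to labels with a lookup function when shown, and a later entry notes that categoricals start at 1. The idea can be sketched as below. This is a minimal pure-Python illustration, not the class added by this commit: `normalize_categories` and `format_tick` are hypothetical names, and sorting the labels is an assumption made here for determinism.

```python
def normalize_categories(values):
    """Map arbitrary labels to integer codes starting at 1.

    Returns the codes plus a lookup dict for turning codes back into labels.
    Hypothetical helper; labels are sorted here, which is an assumption.
    """
    labels = sorted(set(values))
    to_code = {label: code for code, label in enumerate(labels, start=1)}
    codes = [to_code[value] for value in values]
    lookup = {code: label for label, code in to_code.items()}
    return codes, lookup


def format_tick(code, lookup):
    """Return the label for a numeric tick, or an empty string if there is none."""
    return str(lookup.get(int(code), ""))


codes, lookup = normalize_categories(["b", "a", "c", "a"])
tick_labels = [format_tick(code, lookup) for code in range(0, len(lookup) + 2)]
```

The numeric codes can be handed to a plotting routine that expects numbers, while `format_tick` supplies the text for a legend or colorbar tick. Codes 0 and `len(lookup) + 1` have no label, so they format as empty strings.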
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -22,6 +22,9 @@ v2022.09.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add scatter plot for DataArrays. Scatter plots now also support 3d plots with
  the z argument. (:pull:`6778`)
  By `Jimmy Westling <https://github.com/illviljan>`_.

Breaking changes
~~~~~~~~~~~~~~~~
201 changes: 106 additions & 95 deletions xarray/plot/dataset_plot.py
@@ -1,51 +1,23 @@
from __future__ import annotations

import functools

import numpy as np
import pandas as pd
import inspect
from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping

from ..core.alignment import broadcast
from .facetgrid import _easy_facetgrid
from .plot import _PlotMethods
from .utils import (
    _add_colorbar,
    _get_nice_quiver_magnitude,
    _infer_meta_data,
    _parse_size,
    _process_cmap_cbar_kwargs,
    get_axis,
)


def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None):

    broadcast_keys = ["x", "y"]
    to_broadcast = [ds[x], ds[y]]
    if hue:
        to_broadcast.append(ds[hue])
        broadcast_keys.append("hue")
    if markersize:
        to_broadcast.append(ds[markersize])
        broadcast_keys.append("size")

    broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast)))

    data = {"x": broadcasted["x"], "y": broadcasted["y"], "hue": None, "sizes": None}

    if hue:
        data["hue"] = broadcasted["hue"]

    if markersize:
        size = broadcasted["size"]

        if size_mapping is None:
            size_mapping = _parse_size(size, size_norm)

        data["sizes"] = size.copy(
            data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape)
        )

    return data
if TYPE_CHECKING:
    from ..core.dataarray import DataArray
    from ..core.types import T_Dataset


class _Dataset_PlotMethods:
@@ -352,67 +324,6 @@ def plotmethod(
    return newplotfunc


@_dsplot
def scatter(ds, x, y, ax, **kwargs):
    """
    Scatter Dataset data variables against each other.
    Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`.
    """

    if "add_colorbar" in kwargs or "add_legend" in kwargs:
        raise ValueError(
            "Dataset.plot.scatter does not accept "
            "'add_colorbar' or 'add_legend'. "
            "Use 'add_guide' instead."
        )

    cmap_params = kwargs.pop("cmap_params")
    hue = kwargs.pop("hue")
    hue_style = kwargs.pop("hue_style")
    markersize = kwargs.pop("markersize", None)
    size_norm = kwargs.pop("size_norm", None)
    size_mapping = kwargs.pop("size_mapping", None)  # set by facetgrid

    # Remove `u` and `v` so they don't get passed to `ax.scatter`
    kwargs.pop("u", None)
    kwargs.pop("v", None)

    # need to infer size_mapping with full dataset
    data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping)

    if hue_style == "discrete":
        primitive = []
        # use pd.unique instead of np.unique because that keeps the order of the labels,
        # which is important to keep them in sync with the ones used in
        # FacetGrid.add_legend
        for label in pd.unique(data["hue"].values.ravel()):
            mask = data["hue"] == label
            if data["sizes"] is not None:
                kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten())

            primitive.append(
                ax.scatter(
                    data["x"].where(mask, drop=True).values.flatten(),
                    data["y"].where(mask, drop=True).values.flatten(),
                    label=label,
                    **kwargs,
                )
            )

    elif hue is None or hue_style == "continuous":
        if data["sizes"] is not None:
            kwargs.update(s=data["sizes"].values.ravel())
        if data["hue"] is not None:
            kwargs.update(c=data["hue"].values.ravel())

        primitive = ax.scatter(
            data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs
        )

    return primitive


@_dsplot
def quiver(ds, x, y, ax, u, v, **kwargs):
    """Quiver plot of Dataset variables.
@@ -497,3 +408,103 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):

    # Return .lines so colorbar creation works properly
    return hdl.lines


def _attach_to_plot_class(plotfunc: Callable) -> None:
    """
    Set the function to the plot class and add a common docstring.
    Use this decorator when relying on DataArray.plot methods for
    creating the Dataset plot.
    TODO: Reduce code duplication.
    * The goal is to reduce code duplication by moving all Dataset
      specific plots to the DataArray side and use this thin wrapper to
      handle the conversion between Dataset and DataArray.
    * Improve docstring handling, maybe reword the DataArray versions to
      explain Datasets better.
    * Consider automatically adding all _PlotMethods to
      _Dataset_PlotMethods.
    Parameters
    ----------
    plotfunc : function
        Function that returns a finished plot primitive.
    """
    # Build on the original docstring:
    original_doc = getattr(_PlotMethods, plotfunc.__name__, object)
    commondoc = original_doc.__doc__
    if commondoc is not None:
        doc_warning = (
            f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}."
            " Some inconsistencies may exist."
        )
        # Add indentation so it matches the original doc:
        commondoc = f"\n\n {doc_warning}\n\n {commondoc}"
    else:
        commondoc = ""
    plotfunc.__doc__ = (
        f" {plotfunc.__doc__}\n\n"
        " The `y` DataArray will be used as base,"
        " any other variables are added as coords.\n\n"
        f"{commondoc}"
    )

    @functools.wraps(plotfunc)
    def plotmethod(self, *args, **kwargs):
        return plotfunc(self._ds, *args, **kwargs)

    # Add to class _PlotMethods
    setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod)


def _normalize_args(plotmethod: str, args, kwargs) -> dict[str, Any]:
    from ..core.dataarray import DataArray

    # Determine positional arguments keyword by inspecting the
    # signature of the plotmethod:
    locals_ = dict(
        inspect.signature(getattr(DataArray().plot, plotmethod))
        .bind(*args, **kwargs)
        .arguments.items()
    )
    locals_.update(locals_.pop("kwargs", {}))

    return locals_


def _temp_dataarray(ds: T_Dataset, y: Hashable, locals_: Mapping) -> DataArray:
    """Create a temporary dataarray with extra coords."""
    from ..core.dataarray import DataArray

    # Base coords:
    coords = dict(ds.coords)

    # Add extra coords to the DataArray from valid kwargs, if using all
    # kwargs there is a risk that we add unnecessary dataarrays as
    # coords straining RAM further for example:
    # ds.both and extend="both" would add ds.both to the coords:
    valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"}
    coord_kwargs = locals_.keys() & valid_coord_kwargs
    for k in coord_kwargs:
        key = locals_[k]
        if ds.data_vars.get(key) is not None:
            coords[key] = ds[key]

    # The dataarray has to include all the dims. Broadcast to that shape
    # and add the additional coords:
    _y = ds[y].broadcast_like(ds)

    return DataArray(_y, coords=coords)


@_attach_to_plot_class
def scatter(ds: T_Dataset, x: Hashable, y: Hashable, *args, **kwargs):
    """Scatter plot Dataset data variables against each other."""
    plotmethod = "scatter"
    kwargs.update(x=x)
    locals_ = _normalize_args(plotmethod, args, kwargs)
    da = _temp_dataarray(ds, y, locals_)

    return getattr(da.plot, plotmethod)(*locals_.pop("args", ()), **locals_)
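
The argument handling in `_normalize_args` and the keyword filtering in `_temp_dataarray` rely on two standard-library behaviors: `inspect.signature(...).bind(...)` turns positional arguments into named ones, and a dict's key view supports set intersection. The sketch below shows both in isolation. The `scatter` function here is a hypothetical stand-in with an assumed signature, not xarray's plot method, and `normalize_args` is an illustrative helper.

```python
import inspect


def scatter(x=None, y=None, z=None, hue=None, **kwargs):
    """Hypothetical stand-in for a plot method; the signature is an assumption."""


def normalize_args(func, args, kwargs):
    """Return all passed arguments keyed by parameter name."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    named = dict(bound.arguments)
    # Flatten the catch-all **kwargs entry into the top-level mapping.
    named.update(named.pop("kwargs", {}))
    return named


# "lon" is passed positionally and ends up under its parameter name, x.
named = normalize_args(scatter, ("lon",), {"hue": "temp", "extend": "both"})

# Only keywords that can refer to variables are considered as extra coords.
valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"}
coord_kwargs = named.keys() & valid_coord_kwargs
```

Filtering by keyword name rather than by value is what keeps `extend="both"` from being treated as a reference to a variable named `both`, which is the case the comment in `_temp_dataarray` describes.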