Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing to plot methods #7052

Merged
merged 59 commits into from
Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
f22ffc7
add plot methods statically and add typing to plot tests
headtr1ck Sep 18, 2022
bef6b7a
whats-new update
headtr1ck Sep 18, 2022
fd12592
fix copy-paste typo
headtr1ck Sep 18, 2022
85b59fc
correct plot signatures
headtr1ck Sep 19, 2022
5a0b8dc
add *some* typing to plot methods
headtr1ck Sep 20, 2022
7edde8f
annotate darray in plot tests
headtr1ck Sep 20, 2022
5f5fffc
correct typing of plot returns
headtr1ck Sep 22, 2022
89d033b
Merge branch 'main' into plotaccessor
headtr1ck Sep 22, 2022
60c6a70
fix plotting overloads
headtr1ck Sep 23, 2022
481565c
add correct overloads to dataset_plot
headtr1ck Sep 23, 2022
ad5e363
Merge branch 'main' into plotaccessor
headtr1ck Sep 23, 2022
27fa07f
update whats-new
headtr1ck Sep 23, 2022
47ef0ce
rename xr.plot.plot module since it shadows the xr.plot.plot method
headtr1ck Sep 25, 2022
45b154a
move accessor to its own module
headtr1ck Sep 25, 2022
e158bf3
move DSPlotAccessor to accessor module
headtr1ck Sep 25, 2022
cadc6de
fix DSPlotAccessor import
headtr1ck Sep 25, 2022
44a0317
add explanation to import statement
headtr1ck Sep 25, 2022
f68a5da
add breaking change to whats-new
headtr1ck Sep 25, 2022
39cb308
Merge branch 'main' into plotaccessor
headtr1ck Sep 25, 2022
5f38366
remove unused `rtol` argument from plot
headtr1ck Sep 25, 2022
e4f792b
make most arguments of plotmethods kwargs only
headtr1ck Sep 25, 2022
226f0e7
fix wrong return types
headtr1ck Sep 25, 2022
84a9ae9
add breaking kwarg change to whats-new
headtr1ck Sep 25, 2022
c1979f5
Merge branch 'main' into plotaccessor
headtr1ck Oct 3, 2022
0758b3b
support for aspect='auto' or 'equal
headtr1ck Oct 3, 2022
f9ce21b
typing support for Dataset FacetGrid
headtr1ck Oct 3, 2022
70c4771
deprecate positional arguments for all plot methods
headtr1ck Oct 3, 2022
f1df41c
add deprecation to whats-new
headtr1ck Oct 3, 2022
bf0ffd1
add FacetGrid generic type
headtr1ck Oct 4, 2022
c61ef0a
fix mypy 0.981 complaints
headtr1ck Oct 4, 2022
a4c7795
fix index errors in plots
headtr1ck Oct 4, 2022
8a557dd
Merge branch 'main' into plotaccessor
headtr1ck Oct 9, 2022
870d5a5
add overloads to scatter
headtr1ck Oct 9, 2022
9d0c859
deprecate scatter args
headtr1ck Oct 10, 2022
1b7e7db
add scatter to accessors and fix docstrings
headtr1ck Oct 10, 2022
ebec845
undo some breaking changes
headtr1ck Oct 11, 2022
90aded4
fix the docstrings and some typing
headtr1ck Oct 11, 2022
3da72e5
fix typing of scatter accessor funcs
headtr1ck Oct 11, 2022
8342f6c
align docstrings with signature and complete typing
headtr1ck Oct 11, 2022
9145f11
add remaining typing
headtr1ck Oct 11, 2022
2f01a17
align more docstrings
headtr1ck Oct 11, 2022
6be4352
re add ValueError for scatter plots with u, v
headtr1ck Oct 11, 2022
9d6a804
fix whats-new conflict
headtr1ck Oct 11, 2022
9686ce6
Merge branch 'main' into plotaccessor
headtr1ck Oct 12, 2022
f61c3d7
fix some typing errors
headtr1ck Oct 12, 2022
1bf0165
more typing fixes
headtr1ck Oct 12, 2022
a62a9a6
fix last mypy complaints
headtr1ck Oct 12, 2022
43b4e7e
try fixing facetgrid examples
headtr1ck Oct 12, 2022
48c9248
fix py3.8 problems
headtr1ck Oct 13, 2022
d101b8b
update plotting.rst
headtr1ck Oct 13, 2022
534c09a
update api
headtr1ck Oct 13, 2022
a0c6b14
update plot docstring
headtr1ck Oct 13, 2022
75f1425
add a tip about yincrease in imshow
headtr1ck Oct 13, 2022
0c25767
set default for x/yincrease in docstring
headtr1ck Oct 13, 2022
a514530
simplify typing
headtr1ck Oct 14, 2022
92462d9
add deprecation date as comment
headtr1ck Oct 14, 2022
8761264
Merge branch 'main' into plotaccessor
headtr1ck Oct 14, 2022
381f00f
update whats-new to new release
headtr1ck Oct 14, 2022
f621ef2
fix whats-new
headtr1ck Oct 14, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ci/requirements/min-all-deps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ channels:
- conda-forge
- nodefaults
dependencies:
# MINIMUM VERSIONS POLICY: see doc/installing.rst
# MINIMUM VERSIONS POLICY: see doc/user-guide/installing.rst
# Run ci/min_deps_check.py to verify that this file respects the policy.
# When upgrading python, numpy, or pandas, must also change
# doc/installing.rst and setup.py.
# doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py.
- python=3.8
- boto3=1.18
- bottleneck=1.3
Expand Down
5 changes: 0 additions & 5 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,6 @@
plot.scatter
plot.surface

plot.FacetGrid.map_dataarray
plot.FacetGrid.set_titles
plot.FacetGrid.set_ticks
plot.FacetGrid.map

CFTimeIndex.all
CFTimeIndex.any
CFTimeIndex.append
Expand Down
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ DataArray
DataArray.plot.line
DataArray.plot.pcolormesh
DataArray.plot.step
DataArray.plot.scatter
DataArray.plot.surface


Expand All @@ -719,6 +720,7 @@ Faceting
plot.FacetGrid.map_dataarray
plot.FacetGrid.map_dataarray_line
plot.FacetGrid.map_dataset
plot.FacetGrid.map_plot1d
plot.FacetGrid.set_axis_labels
plot.FacetGrid.set_ticks
plot.FacetGrid.set_titles
Expand Down
50 changes: 34 additions & 16 deletions doc/user-guide/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Matplotlib must be installed before xarray can plot.

To use xarray's plotting capabilities with time coordinates containing
``cftime.datetime`` objects
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.2.0 or later
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.3.0 or later
needs to be installed.

For more extensive plotting applications consider the following projects:
Expand Down Expand Up @@ -106,7 +106,13 @@ The simplest way to make a plot is to call the :py:func:`DataArray.plot()` metho
@savefig plotting_1d_simple.png width=4in
air1d.plot()

Xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.
Xarray uses the coordinate name along with metadata ``attrs.long_name``,
``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available)
to label the axes.
The names ``long_name``, ``standard_name`` and ``units`` are copied from the
`CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_.
When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``.
The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.

.. ipython:: python

Expand Down Expand Up @@ -340,7 +346,10 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d
y="lat", hue="lon", xincrease=False, yincrease=False
)

In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
In addition, one can use ``xscale, yscale`` to set axes scaling;
``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits.
These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``,
``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.


Two Dimensions
Expand All @@ -350,7 +359,8 @@ Two Dimensions
Simple Example
================

The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional.
The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh`
by default when the data is two-dimensional.

.. ipython:: python
:okwarning:
Expand Down Expand Up @@ -585,7 +595,10 @@ Faceting here refers to splitting an array along one or two dimensions and
plotting each group.
Xarray's basic plotting is useful for plotting two dimensional arrays. What
about three or four dimensional arrays? That's where facets become helpful.
The general approach to plotting here is called “small multiples”, where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship conditioned on one or more other variables is often called a “trellis plot”.
The general approach to plotting here is called “small multiples”, where the
same kind of plot is repeated multiple times, and the specific use of small
multiples to display the same relationship conditioned on one or more other
variables is often called a “trellis plot”.

Consider the temperature data set. There are 4 observations per day for two
years which makes for 2920 values along the time dimension.
Expand Down Expand Up @@ -670,8 +683,8 @@ Faceted plotting supports other arguments common to xarray 2d plots.

@savefig plot_facet_robust.png
g = hasoutliers.plot.pcolormesh(
"lon",
"lat",
x="lon",
y="lat",
col="time",
col_wrap=3,
robust=True,
Expand Down Expand Up @@ -711,7 +724,7 @@ they have been plotted.
.. ipython:: python
:okwarning:

g = t.plot.imshow("lon", "lat", col="time", col_wrap=3, robust=True)
g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True)

for i, ax in enumerate(g.axes.flat):
ax.set_title("Air Temperature %d" % i)
Expand All @@ -727,7 +740,8 @@ they have been plotted.
axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`,
:py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and
:py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information.
Plotting functions can be applied to each subset of the data by calling :py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.
Plotting functions can be applied to each subset of the data by calling
:py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.

TODO: add an example of using the ``map`` method to plot dataset variables
(e.g., with ``plt.quiver``).
Expand Down Expand Up @@ -777,7 +791,8 @@ Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display
@savefig ds_discrete_legend_hue_scatter.png
ds.plot.scatter(x="A", y="B", hue="w", hue_style="discrete")

The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.
The ``markersize`` kwarg lets you vary the point's size by variable value.
You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.

.. ipython:: python
:okwarning:
Expand All @@ -794,7 +809,8 @@ Faceting is also possible
ds.plot.scatter(x="A", y="B", col="x", row="z", hue="w", hue_style="discrete")


For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.
For more advanced scatter plots, we recommend converting the relevant data variables
to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.

Quiver
~~~~~~
Expand All @@ -816,7 +832,8 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto
@savefig ds_facet_quiver.png
ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4)

``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.
``scale`` is required for faceted quiver plots.
The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.

Streamplot
~~~~~~~~~~
Expand All @@ -830,7 +847,8 @@ Visualizing vector fields is also supported with streamline plots:
ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B")


where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible:
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines.
Again, faceting is also possible:

.. ipython:: python
:okwarning:
Expand Down Expand Up @@ -983,7 +1001,7 @@ instead of the default ones:
)

@savefig plotting_example_2d_irreg.png width=4in
da.plot.pcolormesh("lon", "lat")
da.plot.pcolormesh(x="lon", y="lat")

Note that in this case, xarray still follows the pixel centered convention.
This might be undesirable in some cases, for example when your data is defined
Expand All @@ -996,7 +1014,7 @@ this convention when plotting on a map:
import cartopy.crs as ccrs

ax = plt.subplot(projection=ccrs.PlateCarree())
da.plot.pcolormesh("lon", "lat", ax=ax)
da.plot.pcolormesh(x="lon", y="lat", ax=ax)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
@savefig plotting_example_2d_irreg_map.png width=4in
Expand All @@ -1009,7 +1027,7 @@ You can however decide to infer the cell boundaries and use the
:okwarning:

ax = plt.subplot(projection=ccrs.PlateCarree())
da.plot.pcolormesh("lon", "lat", ax=ax, infer_intervals=True)
da.plot.pcolormesh(x="lon", y="lat", ax=ax, infer_intervals=True)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
@savefig plotting_example_2d_irreg_map_infer.png width=4in
Expand Down
12 changes: 10 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ v2022.10.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add static typing to plot accessors (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Breaking changes
~~~~~~~~~~~~~~~~

- Many arguments of plotmethods have been made keyword-only.
- ``xarray.plot.plot`` module renamed to ``xarray.plot.dataarray_plot`` to prevent
shadowing of the ``plot`` method. (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Deprecations
~~~~~~~~~~~~

- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Bug fixes
~~~~~~~~~
Expand Down Expand Up @@ -64,8 +72,8 @@ New Features
the z argument. (:pull:`6778`)
By `Jimmy Westling <https://github.com/illviljan>`_.
- Include the variable name in the error message when CF decoding fails to allow
for easier identification of problematic variables (:issue:`7145`,
:pull:`7147`). By `Spencer Clark <https://github.com/spencerkclark>`_.
for easier identification of problematic variables (:issue:`7145`, :pull:`7147`).
By `Spencer Clark <https://github.com/spencerkclark>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ module = [
"importlib_metadata.*",
"iris.*",
"matplotlib.*",
"mpl_toolkits.*",
"Nio.*",
"nc_time_axis.*",
"numbagg.*",
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ float_to_top = true
default_section = THIRDPARTY
known_first_party = xarray


[aliases]
test = pytest

Expand Down
14 changes: 8 additions & 6 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset
from .types import JoinOptions, T_DataArray, T_Dataset, T_DataWithCoords

DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)

Expand Down Expand Up @@ -944,8 +944,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude):


def _broadcast_helper(
arg: T_DataArrayOrSet, exclude, dims_map, common_coords
) -> T_DataArrayOrSet:
arg: T_DataWithCoords, exclude, dims_map, common_coords
) -> T_DataWithCoords:

from .dataarray import DataArray
from .dataset import Dataset
Expand Down Expand Up @@ -976,14 +976,16 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:

# remove casts once https://github.com/python/mypy/issues/12800 is resolved
if isinstance(arg, DataArray):
return cast("T_DataArrayOrSet", _broadcast_array(arg))
return cast("T_DataWithCoords", _broadcast_array(arg))
elif isinstance(arg, Dataset):
return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
return cast("T_DataWithCoords", _broadcast_dataset(arg))
else:
raise ValueError("all input must be Dataset or DataArray objects")


def broadcast(*args, exclude=None):
# TODO: this typing is too restrictive since it cannot deal with mixed
# DataArray and Dataset types...? Is this a problem?
def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]:
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex
from ..plot.plot import _PlotMethods
from ..plot.accessor import DataArrayPlotAccessor
from ..plot.utils import _get_units_from_attrs
from . import alignment, computation, dtypes, indexing, ops, utils
from ._reductions import DataArrayReductions
Expand Down Expand Up @@ -4189,7 +4189,7 @@ def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArra
def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None:
self.attrs = other.attrs

plot = utils.UncachedAccessor(_PlotMethods)
plot = utils.UncachedAccessor(DataArrayPlotAccessor)

def _title_for_slice(self, truncate: int = 50) -> str:
"""
Expand Down
8 changes: 5 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
from ..plot.accessor import DatasetPlotAccessor
from . import alignment
from . import dtypes as xrdtypes
from . import duck_array_ops, formatting, formatting_html, ops, utils
Expand Down Expand Up @@ -7483,7 +7483,7 @@ def imag(self: T_Dataset) -> T_Dataset:
"""
return self.map(lambda x: x.imag, keep_attrs=True)

plot = utils.UncachedAccessor(_Dataset_PlotMethods)
plot = utils.UncachedAccessor(DatasetPlotAccessor)

def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
"""Returns a ``Dataset`` with variables that match specific conditions.
Expand Down Expand Up @@ -8575,7 +8575,9 @@ def curvefit(
or not isinstance(coords, Iterable)
):
coords = [coords]
coords_ = [self[coord] if isinstance(coord, str) else coord for coord in coords]
coords_: Sequence[DataArray] = [
self[coord] if isinstance(coord, str) else coord for coord in coords
]

# Determine whether any coords are dims on self
for coord in coords_:
Expand Down
10 changes: 9 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from importlib import import_module
from typing import Any, Literal

import numpy as np
from packaging.version import Version
Expand All @@ -9,6 +10,8 @@

integer_types = (int, np.integer)

ModType = Literal["dask", "pint", "cupy", "sparse"]


class DuckArrayModule:
"""
Expand All @@ -18,7 +21,12 @@ class DuckArrayModule:
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
"""

def __init__(self, mod):
module: ModType | None
version: Version
type: tuple[type[Any]] # TODO: improve this? maybe Generic
available: bool

def __init__(self, mod: ModType) -> None:
try:
duck_array_module = import_module(mod)
duck_array_version = Version(duck_array_module.__version__)
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ def dtype(self) -> np.dtype:
CoarsenBoundaryOptions = Literal["exact", "trim", "pad"]
SideOptions = Literal["left", "right"]

ScaleOptions = Literal["linear", "symlog", "log", "logit", None]
HueStyleOptions = Literal["continuous", "discrete", None]
AspectOptions = Union[Literal["auto", "equal"], float, None]
ExtendOptions = Literal["neither", "both", "min", "max", None]

# TODO: Wait until mypy supports recursive objects in combination with typevars
_T = TypeVar("_T")
Expand Down
Loading