diff --git a/.binder/environment.yml b/.binder/environment.yml
index 6fd5829c5e6..6caea42df87 100644
--- a/.binder/environment.yml
+++ b/.binder/environment.yml
@@ -2,7 +2,7 @@ name: xarray-examples
channels:
- conda-forge
dependencies:
- - python=3.8
+ - python=3.9
- boto3
- bottleneck
- cartopy
@@ -26,6 +26,7 @@ dependencies:
- pandas
- pint
- pip
+ - pooch
- pydap
- pynio
- rasterio
diff --git a/.github/workflows/cancel-duplicate-runs.yaml b/.github/workflows/cancel-duplicate-runs.yaml
index 46637bdc112..9f74360b034 100644
--- a/.github/workflows/cancel-duplicate-runs.yaml
+++ b/.github/workflows/cancel-duplicate-runs.yaml
@@ -10,6 +10,6 @@ jobs:
runs-on: ubuntu-latest
if: github.repository == 'pydata/xarray'
steps:
- - uses: styfle/cancel-workflow-action@0.9.0
+ - uses: styfle/cancel-workflow-action@0.9.1
with:
workflow_id: ${{ github.event.workflow.id }}
diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml
index 2b9a6405f21..ed731b25f76 100644
--- a/.github/workflows/ci-additional.yaml
+++ b/.github/workflows/ci-additional.yaml
@@ -103,7 +103,7 @@ jobs:
$PYTEST_EXTRA_FLAGS
- name: Upload code coverage to Codecov
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v2.0.2
with:
file: ./coverage.xml
flags: unittests,${{ matrix.env }}
diff --git a/.github/workflows/ci-pre-commit-autoupdate.yaml b/.github/workflows/ci-pre-commit-autoupdate.yaml
index 8ba7ac14ef1..b10a541197e 100644
--- a/.github/workflows/ci-pre-commit-autoupdate.yaml
+++ b/.github/workflows/ci-pre-commit-autoupdate.yaml
@@ -35,7 +35,6 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
EXECUTE_COMMANDS: |
python -m pre_commit autoupdate
- python .github/workflows/sync_linter_versions.py .pre-commit-config.yaml ci/requirements/mypy_only
python -m pre_commit run --all-files
COMMIT_MESSAGE: 'pre-commit: autoupdate hook versions'
COMMIT_NAME: 'github-actions[bot]'
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 3918f92574d..22a05eb1fc0 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -100,7 +100,7 @@ jobs:
path: pytest.xml
- name: Upload code coverage to Codecov
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v2.0.2
with:
file: ./coverage.xml
flags: unittests
diff --git a/.github/workflows/sync_linter_versions.py b/.github/workflows/sync_linter_versions.py
deleted file mode 100755
index cb0b1355c71..00000000000
--- a/.github/workflows/sync_linter_versions.py
+++ /dev/null
@@ -1,76 +0,0 @@
-#!/usr/bin/env python
-import argparse
-import itertools
-import pathlib
-import re
-
-import yaml
-from packaging import version
-from packaging.requirements import Requirement
-
-operator_re = re.compile("=+")
-
-
-def extract_versions(config):
- repos = config.get("repos")
- if repos is None:
- raise ValueError("invalid pre-commit configuration")
-
- extracted_versions = (
- ((hook["id"], version.parse(repo["rev"])) for hook in repo["hooks"])
- for repo in repos
- )
- return dict(itertools.chain.from_iterable(extracted_versions))
-
-
-def update_requirement(line, new_versions):
- # convert to pep-508 compatible
- preprocessed = operator_re.sub("==", line)
- requirement = Requirement(preprocessed)
-
- specifier, *_ = requirement.specifier
- old_version = specifier.version
- new_version = new_versions.get(requirement.name, old_version)
-
- new_line = f"{requirement.name}={new_version}"
-
- return new_line
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--dry", action="store_true")
- parser.add_argument(
- metavar="pre-commit-config", dest="pre_commit_config", type=pathlib.Path
- )
- parser.add_argument("requirements", type=pathlib.Path)
- args = parser.parse_args()
-
- with args.pre_commit_config.open() as f:
- config = yaml.safe_load(f)
-
- versions = extract_versions(config)
- mypy_version = versions["mypy"]
-
- requirements_text = args.requirements.read_text()
- requirements = requirements_text.split("\n")
- new_requirements = [
- update_requirement(line, versions)
- if line and not line.startswith("# ")
- else line
- for line in requirements
- ]
- new_requirements_text = "\n".join(new_requirements)
-
- if args.dry:
- separator = "\n" + "—" * 80 + "\n"
- print(
- "contents of the old requirements file:",
- requirements_text,
- "contents of the new requirements file:",
- new_requirements_text,
- sep=separator,
- end=separator,
- )
- else:
- args.requirements.write_text(new_requirements_text)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 42d1fb0a4a5..53525d0def9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,16 +8,16 @@ repos:
- id: check-yaml
# isort should run before black as black sometimes tweaks the isort output
- repo: https://github.com/PyCQA/isort
- rev: 5.9.1
+ rev: 5.9.3
hooks:
- id: isort
# https://github.com/python/black#version-control-integration
- repo: https://github.com/psf/black
- rev: 21.6b0
+ rev: 21.7b0
hooks:
- id: black
- repo: https://github.com/keewis/blackdoc
- rev: v0.3.3
+ rev: v0.3.4
hooks:
- id: blackdoc
- repo: https://gitlab.com/pycqa/flake8
@@ -30,7 +30,6 @@ repos:
# - id: velin
# args: ["--write", "--compact"]
- repo: https://github.com/pre-commit/mirrors-mypy
- # version must correspond to the one in .github/workflows/ci-additional.yaml
rev: v0.910
hooks:
- id: mypy
@@ -44,6 +43,7 @@ repos:
types-pytz,
# Dependencies that are typed
numpy,
+ typing-extensions==3.10.0.0,
]
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
# - repo: https://github.com/asottile/pyupgrade
diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py
index 617e9313fd1..405f6cd0530 100644
--- a/asv_bench/benchmarks/repr.py
+++ b/asv_bench/benchmarks/repr.py
@@ -1,8 +1,30 @@
+import numpy as np
import pandas as pd
import xarray as xr
+class Repr:
+ def setup(self):
+ a = np.arange(0, 100)
+ data_vars = dict()
+ for i in a:
+ data_vars[f"long_variable_name_{i}"] = xr.DataArray(
+ name=f"long_variable_name_{i}",
+ data=np.arange(0, 20),
+ dims=[f"long_coord_name_{i}_x"],
+ coords={f"long_coord_name_{i}_x": np.arange(0, 20) * 2},
+ )
+ self.ds = xr.Dataset(data_vars)
+ self.ds.attrs = {f"attr_{k}": 2 for k in a}
+
+ def time_repr(self):
+ repr(self.ds)
+
+ def time_repr_html(self):
+ self.ds._repr_html_()
+
+
class ReprMultiIndex:
def setup(self):
index = pd.MultiIndex.from_product(
diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh
index 073b28b8cfb..92a0f8fc7e7 100755
--- a/ci/install-upstream-wheels.sh
+++ b/ci/install-upstream-wheels.sh
@@ -7,6 +7,7 @@ conda uninstall -y --force \
matplotlib \
dask \
distributed \
+ fsspec \
zarr \
cftime \
rasterio \
@@ -40,4 +41,5 @@ python -m pip install \
git+https://github.com/mapbox/rasterio \
git+https://github.com/hgrecco/pint \
git+https://github.com/pydata/bottleneck \
- git+https://github.com/pydata/sparse
+ git+https://github.com/pydata/sparse \
+ git+https://github.com/intake/filesystem_spec
diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index fc32d35837b..78ead40d5a2 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -10,6 +10,7 @@ dependencies:
- cftime
- dask
- distributed
+ - fsspec!=2021.7.0
- h5netcdf
- h5py
- hdf5
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index c8afc3c21bb..f64ca3677cc 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -12,6 +12,7 @@ dependencies:
- cftime
- dask
- distributed
+ - fsspec!=2021.7.0
- h5netcdf
- h5py
- hdf5
diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 076b0eb452a..fc27d9c3fe8 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -54,7 +54,6 @@
core.rolling.DatasetCoarsen.var
core.rolling.DatasetCoarsen.boundary
core.rolling.DatasetCoarsen.coord_func
- core.rolling.DatasetCoarsen.keep_attrs
core.rolling.DatasetCoarsen.obj
core.rolling.DatasetCoarsen.side
core.rolling.DatasetCoarsen.trim_excess
@@ -120,7 +119,6 @@
core.rolling.DatasetRolling.var
core.rolling.DatasetRolling.center
core.rolling.DatasetRolling.dim
- core.rolling.DatasetRolling.keep_attrs
core.rolling.DatasetRolling.min_periods
core.rolling.DatasetRolling.obj
core.rolling.DatasetRolling.rollings
@@ -199,7 +197,6 @@
core.rolling.DataArrayCoarsen.var
core.rolling.DataArrayCoarsen.boundary
core.rolling.DataArrayCoarsen.coord_func
- core.rolling.DataArrayCoarsen.keep_attrs
core.rolling.DataArrayCoarsen.obj
core.rolling.DataArrayCoarsen.side
core.rolling.DataArrayCoarsen.trim_excess
@@ -263,7 +260,6 @@
core.rolling.DataArrayRolling.var
core.rolling.DataArrayRolling.center
core.rolling.DataArrayRolling.dim
- core.rolling.DataArrayRolling.keep_attrs
core.rolling.DataArrayRolling.min_periods
core.rolling.DataArrayRolling.obj
core.rolling.DataArrayRolling.window
diff --git a/doc/api.rst b/doc/api.rst
index bb3a99bfbb0..fb2296d1226 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -24,7 +24,6 @@ Top-level functions
combine_by_coords
combine_nested
where
- set_options
infer_freq
full_like
zeros_like
@@ -686,6 +685,7 @@ Dataset methods
open_zarr
Dataset.to_netcdf
Dataset.to_pandas
+ Dataset.as_numpy
Dataset.to_zarr
save_mfdataset
Dataset.to_array
@@ -716,6 +716,8 @@ DataArray methods
DataArray.to_pandas
DataArray.to_series
DataArray.to_dataframe
+ DataArray.to_numpy
+ DataArray.as_numpy
DataArray.to_index
DataArray.to_masked_array
DataArray.to_cdms2
diff --git a/doc/conf.py b/doc/conf.py
index f6f7abd61b2..0a6d1504161 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -313,7 +313,7 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
"iris": ("https://scitools-iris.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
- "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy", None),
"numba": ("https://numba.pydata.org/numba-doc/latest", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"dask": ("https://docs.dask.org/en/latest", None),
diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst
index f3d3c0f1902..506236f3b9a 100644
--- a/doc/getting-started-guide/installing.rst
+++ b/doc/getting-started-guide/installing.rst
@@ -8,7 +8,6 @@ Required dependencies
- Python (3.7 or later)
- setuptools (40.4 or later)
-- typing-extensions (3.10 or later)
- `numpy `__ (1.17 or later)
- `pandas `__ (1.0 or later)
@@ -96,7 +95,7 @@ dependencies:
- **setuptools:** 42 months (but no older than 40.4)
- **numpy:** 18 months
(`NEP-29 `_)
-- **dask and dask.distributed:** 12 months (but no older than 2.9)
+- **dask and dask.distributed:** 12 months
- **sparse, pint** and other libraries that rely on
`NEP-18 `_
for integration: very latest available versions only, until the technology will have
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 6cbb284566f..7b46633d293 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,13 +14,68 @@ What's New
np.random.seed(123456)
-.. _whats-new.0.18.3:
-v0.18.3 (unreleased)
+.. _whats-new.0.19.1:
+
+v0.19.1 (unreleased)
---------------------
New Features
~~~~~~~~~~~~
+- Add an option to disable the use of ``bottleneck`` (:pull:`5560`)
+ By `Justus Magin `_.
+- Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`).
+ By `Pushkar Kopparla `_.
+
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+
+Deprecations
+~~~~~~~~~~~~
+
+
+Bug fixes
+~~~~~~~~~
+
+
+Documentation
+~~~~~~~~~~~~~
+
+
+Internal Changes
+~~~~~~~~~~~~~~~~
+
+- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`).
+ By `Deepak Cherian `_.
+- Explicit indexes refactor: decouple ``xarray.Index`` from ``xarray.Variable`` (:pull:`5636`).
+ By `Benoit Bovy `_.
+- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`)
+ By `Jimmy Westling `_.
+
+.. _whats-new.0.19.0:
+
+v0.19.0 (23 July 2021)
+----------------------
+
+This release brings improvements to plotting of categorical data, the ability to specify how attributes
+are combined in xarray operations, a new high-level :py:func:`unify_chunks` function, as well as various
+deprecations, bug fixes, and minor improvements.
+
+
+Many thanks to the 29 contributors to this release:
+
+Andrew Williams, Augustus, Aureliana Barghini, Benoit Bovy, crusaderky, Deepak Cherian, ellesmith88,
+Elliott Sales de Andrade, Giacomo Caria, github-actions[bot], Illviljan, Joeperdefloep, joooeey, Julia Kent,
+Julius Busecke, keewis, Mathias Hauser, Matthias Göbel, Mattia Almansi, Maximilian Roos, Peter Andreas Entschev,
+Ray Bell, Sander, Santiago Soler, Sebastian, Spencer Clark, Stephan Hoyer, Thomas Hirtz, Thomas Nicholas.
+
+New Features
+~~~~~~~~~~~~
+- Allow passing argument ``missing_dims`` to :py:meth:`Variable.transpose` and :py:meth:`Dataset.transpose`
+ (:issue:`5550`, :pull:`5586`)
+ By `Giacomo Caria `_.
- Allow passing a dictionary as coords to a :py:class:`DataArray` (:issue:`5527`,
reverts :pull:`1539`, which had deprecated this due to python's inconsistent ordering in earlier versions).
By `Sander van Rijn `_.
@@ -53,6 +108,10 @@ New Features
- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None``
(:issue:`5510`).
By `Elle Smith `_.
+- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
+ By `Tom Nicholas `_.
+- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`).
+ By `Tom Nicholas `_.
Breaking changes
~~~~~~~~~~~~~~~~
@@ -62,6 +121,10 @@ Breaking changes
pre-existing array values. This is a safer default than the prior ``mode="a"``,
and allows for higher performance writes (:pull:`5252`).
By `Stephan Hoyer `_.
+- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead
+ of `datasets` so anyone calling this method using a named parameter will need to update
+ the name accordingly (:issue:`3248`, :pull:`4696`).
+ By `Augustus Ijams `_.
Deprecations
~~~~~~~~~~~~
@@ -71,6 +134,10 @@ Performance
- Significantly faster unstacking to a ``sparse`` array. :pull:`5577`
By `Deepak Cherian `_.
+- Removed the deprecated ``dim`` kwarg to :py:func:`DataArray.integrate` (:pull:`5630`)
+- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.rolling` (:pull:`5630`)
+- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.coarsen` (:pull:`5630`)
+- Completed deprecation of passing an ``xarray.DataArray`` to :py:func:`Variable` - will now raise a ``TypeError`` (:pull:`5630`)
Bug fixes
~~~~~~~~~
@@ -90,10 +157,9 @@ Bug fixes
- Plotting a pcolormesh with ``xscale="log"`` and/or ``yscale="log"`` works as
expected after improving the way the interval breaks are generated (:issue:`5333`).
By `Santiago Soler `_
-
-
-Documentation
-~~~~~~~~~~~~~
+- :py:func:`combine_by_coords` can now handle combining a list of unnamed
+ ``DataArray`` as input (:issue:`3248`, :pull:`4696`).
+ By `Augustus Ijams `_.
Internal Changes
@@ -104,7 +170,6 @@ Internal Changes
- Publish test results & timings on each PR.
(:pull:`5537`)
By `Maximilian Roos `_.
-
- Explicit indexes refactor: add a ``xarray.Index.query()`` method in which
one may eventually provide a custom implementation of label-based data
selection (not ready yet for public use). Also refactor the internal,
@@ -149,22 +214,9 @@ New Features
- Raise more informative error when decoding time variables with invalid reference dates.
(:issue:`5199`, :pull:`5288`). By `Giacomo Caria `_.
-Breaking changes
-~~~~~~~~~~~~~~~~
-- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead
- of `datasets` so anyone calling this method using a named parameter will need to update
- the name accordingly (:issue:`3248`, :pull:`4696`).
- By `Augustus Ijams `_.
-
-Deprecations
-~~~~~~~~~~~~
-
Bug fixes
~~~~~~~~~
-- :py:func:`combine_by_coords` can now handle combining a list of unnamed
- ``DataArray`` as input (:issue:`3248`, :pull:`4696`).
- By `Augustus Ijams `_.
- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying
an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in
0.18.0.
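
A minimal sketch of the new ``use_bottleneck`` option mentioned in the 0.19.1 entries above; the option name is taken from the ``Dataset.rank`` change further down in this diff, and the exact fallback behaviour of each operation is an assumption:

import numpy as np
import xarray as xr

da = xr.DataArray(np.array([1.0, np.nan, 3.0]), dims="x")

# set_options works as a context manager; with use_bottleneck=False,
# nan-skipping reductions fall back to pure numpy, and operations that
# require bottleneck (e.g. Dataset.rank, per the change below) raise instead.
with xr.set_options(use_bottleneck=False):
    print(da.sum().item())  # 4.0, computed without bottleneck
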
diff --git a/setup.cfg b/setup.cfg
index 5a6e0b3435d..c44d207bf0f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -67,6 +67,7 @@ classifiers =
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
[options]
@@ -78,7 +79,6 @@ install_requires =
numpy >= 1.17
pandas >= 1.0
setuptools >= 40.4 # For pkg_resources
- typing-extensions >= 3.10 # Backported type hints
[options.extras_require]
io =
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index 49a5a9ec7ae..1891fac8668 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -162,7 +162,14 @@ def default(s):
return parsed_meta
-def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=None):
+def open_rasterio(
+ filename,
+ parse_coordinates=None,
+ chunks=None,
+ cache=None,
+ lock=None,
+ **kwargs,
+):
"""Open a file with rasterio (experimental).
This should work with any file that rasterio can open (most often:
@@ -272,7 +279,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=None):
if lock is None:
lock = RASTERIO_LOCK
- manager = CachingFileManager(rasterio.open, filename, lock=lock, mode="r")
+ manager = CachingFileManager(
+ rasterio.open,
+ filename,
+ lock=lock,
+ mode="r",
+ kwargs=kwargs,
+ )
riods = manager.acquire()
if vrt_params is not None:
riods = WarpedVRT(riods, **vrt_params)
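
The extra ``**kwargs`` are forwarded to ``rasterio.open`` through the ``CachingFileManager``; per the whats-new entry they can be used to read overviews. A sketch, assuming a hypothetical ``example.tif`` with overviews and that ``overview_level`` is accepted as a rasterio/GDAL open option:

import xarray as xr

# overview_level is passed through to rasterio.open(), selecting a level of
# the overview pyramid instead of the full-resolution raster (assumption:
# the file has overviews and the installed rasterio/GDAL supports this option).
da = xr.open_rasterio("example.tif", overview_level=1)
print(da.shape)
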
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index d492e3dfb92..aec12d2b154 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -737,6 +737,7 @@ def open_zarr(
See Also
--------
open_dataset
+ open_mfdataset
References
----------
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 8f2ba2f4b97..a53ac094253 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -18,7 +18,7 @@
import pandas as pd
from . import dtypes
-from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index
+from .indexes import Index, PandasIndex, get_indexer_nd
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
from .variable import IndexVariable, Variable
@@ -53,7 +53,10 @@ def _get_joiner(join, index_cls):
def _override_indexes(objects, all_indexes, exclude):
for dim, dim_indexes in all_indexes.items():
if dim not in exclude:
- lengths = {index.size for index in dim_indexes}
+ lengths = {
+ getattr(index, "size", index.to_pandas_index().size)
+ for index in dim_indexes
+ }
if len(lengths) != 1:
raise ValueError(
f"Indexes along dimension {dim!r} don't have the same length."
@@ -300,16 +303,14 @@ def align(
joined_indexes = {}
for dim, matching_indexes in all_indexes.items():
if dim in indexes:
- # TODO: benbovy - flexible indexes. maybe move this logic in util func
- if isinstance(indexes[dim], Index):
- index = indexes[dim]
- else:
- index = PandasIndex(safe_cast_to_index(indexes[dim]))
+ index, _ = PandasIndex.from_pandas_index(
+ safe_cast_to_index(indexes[dim]), dim
+ )
if (
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
- joined_indexes[dim] = index
+ joined_indexes[dim] = indexes[dim]
else:
if (
any(
@@ -323,17 +324,18 @@ def align(
joiner = _get_joiner(join, type(matching_indexes[0]))
index = joiner(matching_indexes)
# make sure str coords are not cast to object
- index = maybe_coerce_to_str(index, all_coords[dim])
+ index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim])
joined_indexes[dim] = index
else:
index = all_coords[dim][0]
if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
- # TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
- # Some indexes may not have a defined size (e.g., built from multiple coords of
- # different sizes)
- labeled_size = index.size
+ # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
+ if isinstance(index, PandasIndex):
+ labeled_size = index.to_pandas_index().size
+ else:
+ labeled_size = index.size
if len(unlabeled_sizes | {labeled_size}) > 1:
raise ValueError(
f"arguments without labels along dimension {dim!r} cannot be "
@@ -350,7 +352,14 @@ def align(
result = []
for obj in objects:
- valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims}
+ # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
+ valid_indexers = {}
+ for k, index in joined_indexes.items():
+ if k in obj.dims:
+ if isinstance(index, Index):
+ valid_indexers[k] = index.to_pandas_index()
+ else:
+ valid_indexers[k] = index
if not valid_indexers:
# fast path for no reindexing necessary
new_obj = obj.copy(deep=copy)
@@ -471,7 +480,11 @@ def reindex_like_indexers(
ValueError
If any dimensions without labels have different sizes.
"""
- indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
+ # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
+ # this doesn't support yet indexes other than pd.Index
+ indexers = {
+ k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims
+ }
for dim in other.dims:
if dim not in indexers and dim in target.dims:
@@ -560,7 +573,8 @@ def reindex_variables(
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
)
- target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim]))
+ target = safe_cast_to_index(indexers[dim])
+ new_indexes[dim] = PandasIndex(target, dim)
if dim in indexes:
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index de6d16ef5c3..7e1565e50de 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets):
"inferring concatenation order"
)
- # TODO (benbovy, flexible indexes): all indexes should be Pandas.Index
- # get pd.Index objects from Index objects
- indexes = [index.array for index in indexes]
+ # TODO (benbovy, flexible indexes): support flexible indexes?
+ indexes = [index.to_pandas_index() for index in indexes]
# If dimension coordinate values are same on every dataset then
# should be leaving this dimension alone (it's just a "bystander")
@@ -635,7 +634,7 @@ def _combine_single_variable_hypercube(
return concatenated
-# TODO remove empty list default param after version 0.19, see PR4696
+# TODO remove empty list default param after version 0.21, see PR4696
def combine_by_coords(
data_objects=[],
compat="no_conflicts",
@@ -849,11 +848,11 @@ def combine_by_coords(
precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176
"""
- # TODO remove after version 0.19, see PR4696
+ # TODO remove after version 0.21, see PR4696
if datasets is not None:
warnings.warn(
"The datasets argument has been renamed to `data_objects`."
- " In future passing a value for datasets will raise an error."
+ " From 0.21 on passing a value for datasets will raise an error."
)
data_objects = datasets
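
A short sketch of the two ``combine_by_coords`` changes recorded in the whats-new entries above (the ``data_objects`` rename and support for unnamed ``DataArray`` inputs):

import numpy as np
import xarray as xr

da1 = xr.DataArray(np.zeros(3), dims="x", coords={"x": [0, 1, 2]})
da2 = xr.DataArray(np.ones(3), dims="x", coords={"x": [3, 4, 5]})

# The first parameter is now named data_objects; passing datasets= still works
# until 0.21 but emits the deprecation warning shown above. Unnamed DataArrays
# are accepted and combined by their coordinate values.
combined = xr.combine_by_coords([da1, da2])
print(combined.sizes)
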
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 7b6e9198b43..ab822f576d3 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -821,7 +821,6 @@ def rolling(
dim: Mapping[Hashable, int] = None,
min_periods: int = None,
center: Union[bool, Mapping[Hashable, bool]] = False,
- keep_attrs: bool = None,
**window_kwargs: int,
):
"""
@@ -889,9 +888,7 @@ def rolling(
"""
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
- return self._rolling_cls(
- self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs
- )
+ return self._rolling_cls(self, dim, min_periods=min_periods, center=center)
def rolling_exp(
self,
@@ -940,7 +937,6 @@ def coarsen(
boundary: str = "exact",
side: Union[str, Mapping[Hashable, str]] = "left",
coord_func: str = "mean",
- keep_attrs: bool = None,
**window_kwargs: int,
):
"""
@@ -1009,7 +1005,6 @@ def coarsen(
boundary=boundary,
side=side,
coord_func=coord_func,
- keep_attrs=keep_attrs,
)
def resample(
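
With ``keep_attrs`` removed from the ``rolling``/``coarsen`` constructors, attribute handling no longer belongs to the constructor call. A sketch of the updated pattern; passing ``keep_attrs`` at reduction time (or relying on the global ``keep_attrs`` option) is an assumption based on the deprecation path rather than on this diff:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims="time", attrs={"units": "m"})

# keep_attrs is no longer accepted by rolling()/coarsen() themselves
smoothed = da.rolling(time=3, center=True, min_periods=1).mean(keep_attrs=True)
coarse = da.coarsen(time=2, boundary="trim").mean()
print(smoothed.attrs, coarse.attrs)
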
diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index 87f67028862..5eeb22767c8 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -66,9 +66,15 @@ def push(array, n, axis):
)
if all(c == 1 for c in array.chunks[axis]):
array = array.rechunk({axis: 2})
- pushed = array.map_blocks(push, axis=axis, n=n)
+ pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta)
if len(array.chunks[axis]) > 1:
pushed = pushed.map_overlap(
- push, axis=axis, n=n, depth={axis: (1, 0)}, boundary="none"
+ push,
+ axis=axis,
+ n=n,
+ depth={axis: (1, 0)},
+ boundary="none",
+ dtype=array.dtype,
+ meta=array._meta,
)
return pushed
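
The extra ``dtype``/``meta`` arguments tell dask what ``push`` returns, so ``map_blocks``/``map_overlap`` can skip calling the function on dummy zero-size chunks to infer the output type. A standalone sketch with a stand-in for the bottleneck-backed ``push`` (the fill function below is illustrative only):

import dask.array
import numpy as np

def fill_forward(block, axis, n):
    # stand-in for push(); a cumulative maximum is enough to show the pattern
    return np.maximum.accumulate(block, axis=axis)

arr = dask.array.from_array(np.array([0.0, 2.0, 1.0, 3.0]), chunks=2)

# dtype/meta are provided up front, so no dummy call is needed for inference
out = arr.map_blocks(fill_forward, axis=0, n=1, dtype=arr.dtype, meta=arr._meta)
print(out.compute())
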
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index b4d553c235a..900af885319 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -51,13 +51,7 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
-from .indexes import (
- Index,
- Indexes,
- default_indexes,
- propagate_indexes,
- wrap_pandas_index,
-)
+from .indexes import Index, Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS, _get_keep_attrs
@@ -426,12 +420,12 @@ def __init__(
self._close = None
def _replace(
- self,
+ self: T_DataArray,
variable: Variable = None,
coords=None,
name: Union[Hashable, None, Default] = _default,
indexes=None,
- ) -> "DataArray":
+ ) -> T_DataArray:
if variable is None:
variable = self.variable
if coords is None:
@@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
return self
coords = self._coords.copy()
for name, idx in indexes.items():
- coords[name] = IndexVariable(name, idx)
+ coords[name] = IndexVariable(name, idx.to_pandas_index())
obj = self._replace(coords=coords)
# switch from dimension to level names, if necessary
dim_names: Dict[Any, str] = {}
for dim, idx in indexes.items():
- # TODO: benbovy - flexible indexes: update when MultiIndex has its own class
- pd_idx = idx.array
- if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
+ pd_idx = idx.to_pandas_index()
+ if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
@@ -623,7 +616,16 @@ def __len__(self) -> int:
@property
def data(self) -> Any:
- """The array's data as a dask or numpy array"""
+ """
+ The DataArray's data as an array. The underlying array type
+ (e.g. dask, sparse, pint) is preserved.
+
+ See Also
+ --------
+ DataArray.to_numpy
+ DataArray.as_numpy
+ DataArray.values
+ """
return self.variable.data
@data.setter
@@ -632,13 +634,46 @@ def data(self, value: Any) -> None:
@property
def values(self) -> np.ndarray:
- """The array's data as a numpy.ndarray"""
+ """
+ The array's data as a numpy.ndarray.
+
+ If the array's data is not a numpy.ndarray this will attempt to convert
+ it naively using np.array(), which will raise an error if the array
+ type does not support coercion like this (e.g. cupy).
+ """
return self.variable.values
@values.setter
def values(self, value: Any) -> None:
self.variable.values = value
+ def to_numpy(self) -> np.ndarray:
+ """
+ Coerces wrapped data to numpy and returns a numpy.ndarray.
+
+ See also
+ --------
+ DataArray.as_numpy : Same but returns the surrounding DataArray instead.
+ Dataset.as_numpy
+ DataArray.values
+ DataArray.data
+ """
+ return self.variable.to_numpy()
+
+ def as_numpy(self: T_DataArray) -> T_DataArray:
+ """
+ Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.
+
+ See also
+ --------
+ DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object.
+ Dataset.as_numpy : Converts all variables in a Dataset.
+ DataArray.values
+ DataArray.data
+ """
+ coords = {k: v.as_numpy() for k, v in self._coords.items()}
+ return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes)
+
@property
def _in_memory(self) -> bool:
return self.variable._in_memory
@@ -931,7 +966,7 @@ def persist(self, **kwargs) -> "DataArray":
ds = self._to_temp_dataset().persist(**kwargs)
return self._from_temp_dataset(ds)
- def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
+ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
"""Returns a copy of this array.
If `deep=True`, a deep copy is made of the data array.
@@ -1004,12 +1039,7 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
if self._indexes is None:
indexes = self._indexes
else:
- # TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
- # xarray Index needs a copy method.
- indexes = {
- k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep))
- for k, v in self._indexes.items()
- }
+ indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
return self._replace(variable, coords, indexes=indexes)
def __copy__(self) -> "DataArray":
@@ -2742,7 +2772,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
result : MaskedArray
Masked where invalid values (nan or inf) occur.
"""
- values = self.values # only compute lazy arrays once
+ values = self.to_numpy() # only compute lazy arrays once
isnull = pd.isnull(values)
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
@@ -3540,8 +3570,6 @@ def integrate(
self,
coord: Union[Hashable, Sequence[Hashable]] = None,
datetime_unit: str = None,
- *,
- dim: Union[Hashable, Sequence[Hashable]] = None,
) -> "DataArray":
"""Integrate along the given coordinate using the trapezoidal rule.
@@ -3553,8 +3581,6 @@ def integrate(
----------
coord : hashable, or sequence of hashable
Coordinate(s) used for the integration.
- dim : hashable, or sequence of hashable
- Coordinate(s) used for the integration.
datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
'ps', 'fs', 'as'}, optional
Specify the unit if a datetime coordinate is used.
@@ -3591,21 +3617,6 @@ def integrate(
array([5.4, 6.6, 7.8])
Dimensions without coordinates: y
"""
- if dim is not None and coord is not None:
- raise ValueError(
- "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead."
- )
-
- if dim is not None and coord is None:
- coord = dim
- msg = (
- "The `dim` keyword argument to `DataArray.integrate` is "
- "being replaced with `coord`, for consistency with "
- "`Dataset.integrate`. Please pass `coord` instead."
- " `dim` will be removed in version 0.19.0."
- )
- warnings.warn(msg, FutureWarning, stacklevel=2)
-
ds = self._to_temp_dataset().integrate(coord, datetime_unit)
return self._from_temp_dataset(ds)
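
A small sketch of the new ``to_numpy``/``as_numpy`` methods added above, using a dask-backed array (any wrapped array type works; dask here is only for illustration and assumed to be installed):

import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(dask.array.arange(4, chunks=2), dims="x", coords={"x": np.arange(4)})

print(type(da.data))             # dask array: .data preserves the wrapped type
print(type(da.to_numpy()))       # numpy.ndarray: wrapped data is computed/coerced
print(type(da.as_numpy().data))  # DataArray with data and coords coerced to numpy
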
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index ce5c6ee720e..6b462202572 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -71,7 +71,6 @@
propagate_indexes,
remove_unused_levels_categories,
roll_index,
- wrap_pandas_index,
)
from .indexing import is_fancy_indexer
from .merge import (
@@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset":
variables = self._variables.copy()
new_indexes = dict(self.xindexes)
for name, idx in indexes.items():
- variables[name] = IndexVariable(name, idx)
+ variables[name] = IndexVariable(name, idx.to_pandas_index())
new_indexes[name] = idx
obj = self._replace(variables, indexes=new_indexes)
@@ -1323,6 +1322,18 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset":
return self._replace(variables, attrs=attrs)
+ def as_numpy(self: "Dataset") -> "Dataset":
+ """
+ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
+
+ See also
+ --------
+ DataArray.as_numpy
+ DataArray.to_numpy : Returns only the data as a numpy.ndarray object.
+ """
+ numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
+ return self._replace(variables=numpy_variables)
+
@property
def _level_coords(self) -> Dict[str, Hashable]:
"""Return a mapping of all MultiIndex levels and their corresponding
@@ -2462,6 +2473,10 @@ def sel(
pos_indexers, new_indexes = remap_label_indexers(
self, indexers=indexers, method=method, tolerance=tolerance
)
+ # TODO: benbovy - flexible indexes: also use variables returned by Index.query
+ # (temporary dirty fix).
+ new_indexes = {k: v[0] for k, v in new_indexes.items()}
+
result = self.isel(indexers=pos_indexers, drop=drop)
return result._overwrite_indexes(new_indexes)
@@ -3285,20 +3300,21 @@ def _rename_dims(self, name_dict):
return {name_dict.get(k, k): v for k, v in self.dims.items()}
def _rename_indexes(self, name_dict, dims_set):
+ # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645
if self._indexes is None:
return None
indexes = {}
- for k, v in self.xindexes.items():
- # TODO: benbovy - flexible indexes: make it compatible with any xarray Index
- index = v.to_pandas_index()
+ for k, v in self.indexes.items():
new_name = name_dict.get(k, k)
if new_name not in dims_set:
continue
- if isinstance(index, pd.MultiIndex):
- new_names = [name_dict.get(k, k) for k in index.names]
- indexes[new_name] = PandasMultiIndex(index.rename(names=new_names))
+ if isinstance(v, pd.MultiIndex):
+ new_names = [name_dict.get(k, k) for k in v.names]
+ indexes[new_name] = PandasMultiIndex(
+ v.rename(names=new_names), new_name
+ )
else:
- indexes[new_name] = PandasIndex(index.rename(new_name))
+ indexes[new_name] = PandasIndex(v.rename(new_name), new_name)
return indexes
def _rename_all(self, name_dict, dims_dict):
@@ -3527,7 +3543,10 @@ def swap_dims(
if new_index.nlevels == 1:
# make sure index name matches dimension name
new_index = new_index.rename(k)
- indexes[k] = wrap_pandas_index(new_index)
+ if isinstance(new_index, pd.MultiIndex):
+ indexes[k] = PandasMultiIndex(new_index, k)
+ else:
+ indexes[k] = PandasIndex(new_index, k)
else:
var = v.to_base_variable()
var.dims = dims
@@ -3800,7 +3819,7 @@ def reorder_levels(
raise ValueError(f"coordinate {dim} has no MultiIndex")
new_index = index.reorder_levels(order)
variables[dim] = IndexVariable(coord.dims, new_index)
- indexes[dim] = PandasMultiIndex(new_index)
+ indexes[dim] = PandasMultiIndex(new_index, dim)
return self._replace(variables, indexes=indexes)
@@ -3828,7 +3847,7 @@ def _stack_once(self, dims, new_dim):
coord_names = set(self._coord_names) - set(dims) | {new_dim}
indexes = {k: v for k, v in self.xindexes.items() if k not in dims}
- indexes[new_dim] = wrap_pandas_index(idx)
+ indexes[new_dim] = PandasMultiIndex(idx, new_dim)
return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
@@ -4019,8 +4038,9 @@ def _unstack_once(
variables[name] = var
for name, lev in zip(index.names, index.levels):
- variables[name] = IndexVariable(name, lev)
- indexes[name] = PandasIndex(lev)
+ idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
+ variables[name] = idx_vars[name]
+ indexes[name] = idx
coord_names = set(self._coord_names) - {dim} | set(index.names)
@@ -4058,8 +4078,9 @@ def _unstack_full_reindex(
variables[name] = var
for name, lev in zip(new_dim_names, index.levels):
- variables[name] = IndexVariable(name, lev)
- indexes[name] = PandasIndex(lev)
+ idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
+ variables[name] = idx_vars[name]
+ indexes[name] = idx
coord_names = set(self._coord_names) - {dim} | set(new_dim_names)
@@ -4160,6 +4181,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset":
"""Update this dataset's variables with those from another dataset.
Just like :py:meth:`dict.update` this is an in-place operation.
+ For a non-inplace version, see :py:meth:`Dataset.merge`.
Parameters
----------
@@ -4178,7 +4200,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset":
Updated dataset. Note that since the update is in-place this is the input
dataset.
- It is deprecated since version 0.17 and scheduled to be removed in 0.19.
+ It is deprecated since version 0.17 and scheduled to be removed in 0.21.
Raises
------
@@ -4189,6 +4211,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset":
See Also
--------
Dataset.assign
+ Dataset.merge
"""
merge_result = dataset_update_method(self, other)
return self._replace(inplace=True, **merge_result._asdict())
@@ -4262,6 +4285,10 @@ def merge(
------
MergeError
If any variables conflict (see ``compat``).
+
+ See Also
+ --------
+ Dataset.update
"""
other = other.to_dataset() if isinstance(other, xr.DataArray) else other
merge_result = dataset_merge_method(
@@ -4542,7 +4569,11 @@ def drop_dims(
drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims}
return self.drop_vars(drop_vars)
- def transpose(self, *dims: Hashable) -> "Dataset":
+ def transpose(
+ self,
+ *dims: Hashable,
+ missing_dims: str = "raise",
+ ) -> "Dataset":
"""Return a new Dataset object with all array dimensions transposed.
Although the order of dimensions on each array will change, the dataset
@@ -4553,6 +4584,12 @@ def transpose(self, *dims: Hashable) -> "Dataset":
*dims : hashable, optional
By default, reverse the dimensions on each array. Otherwise,
reorder the dimensions to this order.
+ missing_dims : {"raise", "warn", "ignore"}, default: "raise"
+ What to do if dimensions that should be selected from are not present in the
+ Dataset:
+ - "raise": raise an exception
+ - "warn": raise a warning, and ignore the missing dimensions
+ - "ignore": ignore the missing dimensions
Returns
-------
@@ -4571,12 +4608,10 @@ def transpose(self, *dims: Hashable) -> "Dataset":
numpy.transpose
DataArray.transpose
"""
- if dims:
- if set(dims) ^ set(self.dims) and ... not in dims:
- raise ValueError(
- f"arguments to transpose ({dims}) must be "
- f"permuted dataset dimensions ({tuple(self.dims)})"
- )
+ # Use infix_dims to check once for missing dimensions
+ if len(dims) != 0:
+ _ = list(infix_dims(dims, self.dims, missing_dims))
+
ds = self.copy()
for name, var in self._variables.items():
var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,)))
@@ -5812,10 +5847,13 @@ def diff(self, dim, n=1, label="upper"):
indexes = dict(self.xindexes)
if dim in indexes:
- # TODO: benbovy - flexible indexes: check slicing of xarray indexes?
- # or only allow this for pandas indexes?
- index = indexes[dim].to_pandas_index()
- indexes[dim] = PandasIndex(index[kwargs_new[dim]])
+ if isinstance(indexes[dim], PandasIndex):
+ # maybe optimize? (pandas index already indexed above with var.isel)
+ new_index = indexes[dim].index[kwargs_new[dim]]
+ if isinstance(new_index, pd.MultiIndex):
+ indexes[dim] = PandasMultiIndex(new_index, dim)
+ else:
+ indexes[dim] = PandasIndex(new_index, dim)
difference = self._replace_with_new_dims(variables, indexes=indexes)
@@ -6189,6 +6227,12 @@ def rank(self, dim, pct=False, keep_attrs=None):
ranked : Dataset
Variables that do not depend on `dim` are dropped.
"""
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "rank requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
if dim not in self.dims:
raise ValueError(f"Dataset does not contain the dimension: {dim}")
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 07864e81bb6..7f292605e63 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -387,12 +387,14 @@ def _mapping_repr(
elif len_mapping > max_rows:
summary = [f"{summary[0]} ({max_rows}/{len_mapping})"]
first_rows = max_rows // 2 + max_rows % 2
- items = list(mapping.items())
- summary += [summarizer(k, v, col_width) for k, v in items[:first_rows]]
+ keys = list(mapping.keys())
+ summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]]
if max_rows > 1:
last_rows = max_rows // 2
summary += [pretty_print(" ...", col_width) + " ..."]
- summary += [summarizer(k, v, col_width) for k, v in items[-last_rows:]]
+ summary += [
+ summarizer(k, mapping[k], col_width) for k in keys[-last_rows:]
+ ]
else:
summary += [summarizer(k, v, col_width) for k, v in mapping.items()]
else:
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 90d8eec6623..429c37af588 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1,6 +1,4 @@
import collections.abc
-from contextlib import suppress
-from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
@@ -18,28 +16,26 @@
import pandas as pd
from . import formatting, utils
-from .indexing import ExplicitlyIndexedNDArrayMixin, NumpyIndexingAdapter
-from .npcompat import DTypeLike
+from .indexing import (
+ LazilyIndexedArray,
+ PandasIndexingAdapter,
+ PandasMultiIndexingAdapter,
+)
from .utils import is_dict_like, is_scalar
if TYPE_CHECKING:
- from .variable import Variable
+ from .variable import IndexVariable, Variable
+
+IndexVars = Dict[Hashable, "IndexVariable"]
class Index:
"""Base class inherited by all xarray-compatible indexes."""
- __slots__ = ("coord_names",)
-
- def __init__(self, coord_names: Union[Hashable, Iterable[Hashable]]):
- if isinstance(coord_names, Hashable):
- coord_names = (coord_names,)
- self.coord_names = tuple(coord_names)
-
@classmethod
def from_variables(
- cls, variables: Dict[Hashable, "Variable"], **kwargs
- ): # pragma: no cover
+ cls, variables: Mapping[Hashable, "Variable"]
+ ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover
raise NotImplementedError()
def to_pandas_index(self) -> pd.Index:
@@ -52,8 +48,10 @@ def to_pandas_index(self) -> pd.Index:
"""
raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.")
- def query(self, labels: Dict[Hashable, Any]): # pragma: no cover
- raise NotImplementedError
+ def query(
+ self, labels: Dict[Hashable, Any]
+ ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover
+ raise NotImplementedError()
def equals(self, other): # pragma: no cover
raise NotImplementedError()
@@ -64,6 +62,13 @@ def union(self, other): # pragma: no cover
def intersection(self, other): # pragma: no cover
raise NotImplementedError()
+ def copy(self, deep: bool = True): # pragma: no cover
+ raise NotImplementedError()
+
+ def __getitem__(self, indexer: Any):
+ # if not implemented, index will be dropped from the Dataset or DataArray
+ raise NotImplementedError()
+
def _sanitize_slice_element(x):
from .dataarray import DataArray
@@ -138,64 +143,68 @@ def get_indexer_nd(index, labels, method=None, tolerance=None):
return indexer
-class PandasIndex(Index, ExplicitlyIndexedNDArrayMixin):
- """Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""
+class PandasIndex(Index):
+ """Wrap a pandas.Index as an xarray compatible index."""
- __slots__ = ("array", "_dtype")
+ __slots__ = ("index", "dim")
- def __init__(
- self, array: Any, dtype: DTypeLike = None, coord_name: Optional[Hashable] = None
- ):
- if coord_name is None:
- coord_name = tuple()
- super().__init__(coord_name)
+ def __init__(self, array: Any, dim: Hashable):
+ self.index = utils.safe_cast_to_index(array)
+ self.dim = dim
- self.array = utils.safe_cast_to_index(array)
+ @classmethod
+ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
+ from .variable import IndexVariable
- if dtype is None:
- if isinstance(array, pd.PeriodIndex):
- dtype_ = np.dtype("O")
- elif hasattr(array, "categories"):
- # category isn't a real numpy dtype
- dtype_ = array.categories.dtype
- elif not utils.is_valid_numpy_dtype(array.dtype):
- dtype_ = np.dtype("O")
- else:
- dtype_ = array.dtype
+ if len(variables) != 1:
+ raise ValueError(
+ f"PandasIndex only accepts one variable, found {len(variables)} variables"
+ )
+
+ name, var = next(iter(variables.items()))
+
+ if var.ndim != 1:
+ raise ValueError(
+ "PandasIndex only accepts a 1-dimensional variable, "
+ f"variable {name!r} has {var.ndim} dimensions"
+ )
+
+ dim = var.dims[0]
+
+ obj = cls(var.data, dim)
+
+ data = PandasIndexingAdapter(obj.index)
+ index_var = IndexVariable(
+ dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True
+ )
+
+ return obj, {name: index_var}
+
+ @classmethod
+ def from_pandas_index(cls, index: pd.Index, dim: Hashable):
+ from .variable import IndexVariable
+
+ if index.name is None:
+ name = dim
+ index = index.copy()
+ index.name = dim
else:
- dtype_ = np.dtype(dtype) # type: ignore[assignment]
- self._dtype = dtype_
+ name = index.name
+
+ data = PandasIndexingAdapter(index)
+ index_var = IndexVariable(dim, data, fastpath=True)
+
+ return cls(index, dim), {name: index_var}
def to_pandas_index(self) -> pd.Index:
- return self.array
-
- @property
- def dtype(self) -> np.dtype:
- return self._dtype
-
- def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
- if dtype is None:
- dtype = self.dtype
- array = self.array
- if isinstance(array, pd.PeriodIndex):
- with suppress(AttributeError):
- # this might not be public API
- array = array.astype("object")
- return np.asarray(array.values, dtype=dtype)
-
- @property
- def shape(self) -> Tuple[int]:
- return (len(self.array),)
+ return self.index
- def query(
- self, labels, method=None, tolerance=None
- ) -> Tuple[Any, Union["PandasIndex", None]]:
+ def query(self, labels, method=None, tolerance=None):
assert len(labels) == 1
coord_name, label = next(iter(labels.items()))
- index = self.array
if isinstance(label, slice):
- indexer = _query_slice(index, label, coord_name, method, tolerance)
+ indexer = _query_slice(self.index, label, coord_name, method, tolerance)
elif is_dict_like(label):
raise ValueError(
"cannot use a dict-like object for selection on "
@@ -210,7 +219,7 @@ def query(
if label.ndim == 0:
# see https://github.com/pydata/xarray/pull/4292 for details
label_value = label[()] if label.dtype.kind in "mM" else label.item()
- if isinstance(index, pd.CategoricalIndex):
+ if isinstance(self.index, pd.CategoricalIndex):
if method is not None:
raise ValueError(
"'method' is not a valid kwarg when indexing using a CategoricalIndex."
@@ -219,115 +228,114 @@ def query(
raise ValueError(
"'tolerance' is not a valid kwarg when indexing using a CategoricalIndex."
)
- indexer = index.get_loc(label_value)
+ indexer = self.index.get_loc(label_value)
else:
- indexer = index.get_loc(
+ indexer = self.index.get_loc(
label_value, method=method, tolerance=tolerance
)
elif label.dtype.kind == "b":
indexer = label
else:
- indexer = get_indexer_nd(index, label, method, tolerance)
+ indexer = get_indexer_nd(self.index, label, method, tolerance)
if np.any(indexer < 0):
raise KeyError(f"not all values found in index {coord_name!r}")
return indexer, None
def equals(self, other):
- if isinstance(other, pd.Index):
- other = type(self)(other)
- return self.array.equals(other.array)
+ return self.index.equals(other.index)
def union(self, other):
- if isinstance(other, pd.Index):
- other = type(self)(other)
- return type(self)(self.array.union(other.array))
+ new_index = self.index.union(other.index)
+ return type(self)(new_index, self.dim)
def intersection(self, other):
- if isinstance(other, pd.Index):
- other = PandasIndex(other)
- return type(self)(self.array.intersection(other.array))
-
- def __getitem__(
- self, indexer
- ) -> Union[
- "PandasIndex",
- NumpyIndexingAdapter,
- np.ndarray,
- np.datetime64,
- np.timedelta64,
- ]:
- key = indexer.tuple
- if isinstance(key, tuple) and len(key) == 1:
- # unpack key so it can index a pandas.Index object (pandas.Index
- # objects don't like tuples)
- (key,) = key
-
- if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
- return NumpyIndexingAdapter(self.array.values)[indexer]
-
- result = self.array[key]
-
- if isinstance(result, pd.Index):
- result = type(self)(result, dtype=self.dtype)
- else:
- # result is a scalar
- if result is pd.NaT:
- # work around the impossibility of casting NaT with asarray
- # note: it probably would be better in general to return
- # pd.Timestamp rather np.than datetime64 but this is easier
- # (for now)
- result = np.datetime64("NaT", "ns")
- elif isinstance(result, timedelta):
- result = np.timedelta64(getattr(result, "value", result), "ns")
- elif isinstance(result, pd.Timestamp):
- # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668
- # numpy fails to convert pd.Timestamp to np.datetime64[ns]
- result = np.asarray(result.to_datetime64())
- elif self.dtype != object:
- result = np.asarray(result, dtype=self.dtype)
-
- # as for numpy.ndarray indexing, we always want the result to be
- # a NumPy array.
- result = utils.to_0d_array(result)
-
- return result
-
- def transpose(self, order) -> pd.Index:
- return self.array # self.array should be always one-dimensional
-
- def __repr__(self) -> str:
- return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"
-
- def copy(self, deep: bool = True) -> "PandasIndex":
- # Not the same as just writing `self.array.copy(deep=deep)`, as
- # shallow copies of the underlying numpy.ndarrays become deep ones
- # upon pickling
- # >>> len(pickle.dumps((self.array, self.array)))
- # 4000281
- # >>> len(pickle.dumps((self.array, self.array.copy(deep=False))))
- # 8000341
- array = self.array.copy(deep=True) if deep else self.array
- return type(self)(array, self._dtype)
+ new_index = self.index.intersection(other.index)
+ return type(self)(new_index, self.dim)
+
+ def copy(self, deep=True):
+ return type(self)(self.index.copy(deep=deep), self.dim)
+
+ def __getitem__(self, indexer: Any):
+ return type(self)(self.index[indexer], self.dim)
+
+
+def _create_variables_from_multiindex(index, dim, level_meta=None):
+ from .variable import IndexVariable
+
+ if level_meta is None:
+ level_meta = {}
+
+ variables = {}
+
+ dim_coord_adapter = PandasMultiIndexingAdapter(index)
+ variables[dim] = IndexVariable(
+ dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True
+ )
+
+ for level in index.names:
+ meta = level_meta.get(level, {})
+ data = PandasMultiIndexingAdapter(
+ index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter
+ )
+ variables[level] = IndexVariable(
+ dim,
+ data,
+ attrs=meta.get("attrs"),
+ encoding=meta.get("encoding"),
+ fastpath=True,
+ )
+
+ return variables
class PandasMultiIndex(PandasIndex):
- def query(
- self, labels, method=None, tolerance=None
- ) -> Tuple[Any, Union["PandasIndex", None]]:
+ @classmethod
+ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):
+ if any([var.ndim != 1 for var in variables.values()]):
+ raise ValueError("PandasMultiIndex only accepts 1-dimensional variables")
+
+ dims = set([var.dims for var in variables.values()])
+ if len(dims) != 1:
+ raise ValueError(
+ "unmatched dimensions for variables "
+ + ",".join([str(k) for k in variables])
+ )
+
+ dim = next(iter(dims))[0]
+ index = pd.MultiIndex.from_arrays(
+ [var.values for var in variables.values()], names=variables.keys()
+ )
+ obj = cls(index, dim)
+
+ level_meta = {
+ name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
+ for name, var in variables.items()
+ }
+ index_vars = _create_variables_from_multiindex(
+ index, dim, level_meta=level_meta
+ )
+
+ return obj, index_vars
+
+ @classmethod
+ def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable):
+ index_vars = _create_variables_from_multiindex(index, dim)
+ return cls(index, dim), index_vars
+
+ def query(self, labels, method=None, tolerance=None):
if method is not None or tolerance is not None:
raise ValueError(
"multi-index does not support ``method`` and ``tolerance``"
)
- index = self.array
new_index = None
# label(s) given for multi-index level(s)
- if all([lbl in index.names for lbl in labels]):
+ if all([lbl in self.index.names for lbl in labels]):
is_nested_vals = _is_nested_tuple(tuple(labels.values()))
- if len(labels) == index.nlevels and not is_nested_vals:
- indexer = index.get_loc(tuple(labels[k] for k in index.names))
+ if len(labels) == self.index.nlevels and not is_nested_vals:
+ indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names))
else:
for k, v in labels.items():
# index should be an item (i.e. Hashable) not an array-like
@@ -336,7 +344,7 @@ def query(
"Vectorized selection is not "
f"available along coordinate {k!r} (multi-index level)"
)
- indexer, new_index = index.get_loc_level(
+ indexer, new_index = self.index.get_loc_level(
tuple(labels.values()), level=tuple(labels.keys())
)
# GH2619. Raise a KeyError if nothing is chosen
@@ -346,16 +354,18 @@ def query(
# assume one label value given for the multi-index "array" (dimension)
else:
if len(labels) > 1:
- coord_name = next(iter(set(labels) - set(index.names)))
+ coord_name = next(iter(set(labels) - set(self.index.names)))
raise ValueError(
f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) "
- f"and one or more coordinates among {index.names!r} (multi-index levels)"
+ f"and one or more coordinates among {self.index.names!r} (multi-index levels)"
)
coord_name, label = next(iter(labels.items()))
if is_dict_like(label):
- invalid_levels = [name for name in label if name not in index.names]
+ invalid_levels = [
+ name for name in label if name not in self.index.names
+ ]
if invalid_levels:
raise ValueError(
f"invalid multi-index level names {invalid_levels}"
@@ -363,15 +373,15 @@ def query(
return self.query(label)
elif isinstance(label, slice):
- indexer = _query_slice(index, label, coord_name)
+ indexer = _query_slice(self.index, label, coord_name)
elif isinstance(label, tuple):
if _is_nested_tuple(label):
- indexer = index.get_locs(label)
- elif len(label) == index.nlevels:
- indexer = index.get_loc(label)
+ indexer = self.index.get_locs(label)
+ elif len(label) == self.index.nlevels:
+ indexer = self.index.get_loc(label)
else:
- indexer, new_index = index.get_loc_level(
+ indexer, new_index = self.index.get_loc_level(
label, level=list(range(len(label)))
)
@@ -382,7 +392,7 @@ def query(
else _asarray_tuplesafe(label)
)
if label.ndim == 0:
- indexer, new_index = index.get_loc_level(label.item(), level=0)
+ indexer, new_index = self.index.get_loc_level(label.item(), level=0)
elif label.dtype.kind == "b":
indexer = label
else:
@@ -391,21 +401,20 @@ def query(
"Vectorized selection is not available along "
f"coordinate {coord_name!r} with a multi-index"
)
- indexer = get_indexer_nd(index, label)
+ indexer = get_indexer_nd(self.index, label)
if np.any(indexer < 0):
raise KeyError(f"not all values found in index {coord_name!r}")
if new_index is not None:
- new_index = PandasIndex(new_index)
-
- return indexer, new_index
-
-
-def wrap_pandas_index(index):
- if isinstance(index, pd.MultiIndex):
- return PandasMultiIndex(index)
- else:
- return PandasIndex(index)
+ if isinstance(new_index, pd.MultiIndex):
+ new_index, new_vars = PandasMultiIndex.from_pandas_index(
+ new_index, self.dim
+ )
+ else:
+ new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim)
+ return indexer, (new_index, new_vars)
+ else:
+ return indexer, None
def remove_unused_levels_categories(index: pd.Index) -> pd.Index:
@@ -492,7 +501,13 @@ def isel_variable_and_index(
index: Index,
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]],
) -> Tuple["Variable", Optional[Index]]:
- """Index a Variable and pandas.Index together."""
+ """Index a Variable and an Index together.
+
+ If the index cannot be indexed, return None (it will be dropped).
+
+ (note: not compatible yet with xarray flexible indexes).
+
+ """
from .variable import Variable
if not indexers:
@@ -515,8 +530,11 @@ def isel_variable_and_index(
indexer = indexers[dim]
if isinstance(indexer, Variable):
indexer = indexer.data
- pd_index = index.to_pandas_index()
- new_index = wrap_pandas_index(pd_index[indexer])
+ try:
+ new_index = index[indexer]
+ except NotImplementedError:
+ new_index = None
+
return new_variable, new_index
@@ -528,7 +546,7 @@ def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex:
new_idx = pd_index[-count:].append(pd_index[:-count])
else:
new_idx = pd_index[:]
- return PandasIndex(new_idx)
+ return PandasIndex(new_idx, index.dim)
def propagate_indexes(
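
A sketch of the refactored index construction above; these are internal classes (not public API), and the exact import path may change as the explicit-indexes refactor progresses:

import pandas as pd
from xarray.core.indexes import PandasIndex, PandasMultiIndex

# from_pandas_index now returns both the xarray index object and the
# IndexVariable(s) wrapping it, instead of a bare wrapped pandas.Index.
idx, idx_vars = PandasIndex.from_pandas_index(pd.Index([10, 20, 30]), dim="x")
print(idx.dim, list(idx_vars))  # x ['x']

midx = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["one", "two"])
mi, mi_vars = PandasMultiIndex.from_pandas_index(midx, dim="z")
print(list(mi_vars))  # ['z', 'one', 'two'] -> dimension coord plus level coords
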
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 1ace4db241d..70994a36ac8 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -2,12 +2,15 @@
import functools
import operator
from collections import defaultdict
-from typing import Any, Callable, Iterable, List, Tuple, Union
+from contextlib import suppress
+from datetime import timedelta
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from . import duck_array_ops, nputils, utils
+from .npcompat import DTypeLike
from .pycompat import (
dask_array_type,
dask_version,
@@ -569,9 +572,7 @@ def as_indexable(array):
if isinstance(array, np.ndarray):
return NumpyIndexingAdapter(array)
if isinstance(array, pd.Index):
- from .indexes import PandasIndex
-
- return PandasIndex(array)
+ return PandasIndexingAdapter(array)
if isinstance(array, dask_array_type):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
@@ -1259,3 +1260,149 @@ def __setitem__(self, key, value):
def transpose(self, order):
return self.array.transpose(order)
+
+
+class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+ """Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""
+
+ __slots__ = ("array", "_dtype")
+
+ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
+ self.array = utils.safe_cast_to_index(array)
+
+ if dtype is None:
+ if isinstance(array, pd.PeriodIndex):
+ dtype_ = np.dtype("O")
+ elif hasattr(array, "categories"):
+ # category isn't a real numpy dtype
+ dtype_ = array.categories.dtype
+ elif not utils.is_valid_numpy_dtype(array.dtype):
+ dtype_ = np.dtype("O")
+ else:
+ dtype_ = array.dtype
+ else:
+ dtype_ = np.dtype(dtype) # type: ignore[assignment]
+ self._dtype = dtype_
+
+ @property
+ def dtype(self) -> np.dtype:
+ return self._dtype
+
+ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
+ if dtype is None:
+ dtype = self.dtype
+ array = self.array
+ if isinstance(array, pd.PeriodIndex):
+ with suppress(AttributeError):
+ # this might not be public API
+ array = array.astype("object")
+ return np.asarray(array.values, dtype=dtype)
+
+ @property
+ def shape(self) -> Tuple[int]:
+ return (len(self.array),)
+
+ def __getitem__(
+ self, indexer
+ ) -> Union[
+ "PandasIndexingAdapter",
+ NumpyIndexingAdapter,
+ np.ndarray,
+ np.datetime64,
+ np.timedelta64,
+ ]:
+ key = indexer.tuple
+ if isinstance(key, tuple) and len(key) == 1:
+ # unpack key so it can index a pandas.Index object (pandas.Index
+ # objects don't like tuples)
+ (key,) = key
+
+ if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
+ return NumpyIndexingAdapter(self.array.values)[indexer]
+
+ result = self.array[key]
+
+ if isinstance(result, pd.Index):
+ result = type(self)(result, dtype=self.dtype)
+ else:
+ # result is a scalar
+ if result is pd.NaT:
+ # work around the impossibility of casting NaT with asarray
+ # note: it probably would be better in general to return
+ # pd.Timestamp rather than np.datetime64 but this is easier
+ # (for now)
+ result = np.datetime64("NaT", "ns")
+ elif isinstance(result, timedelta):
+ result = np.timedelta64(getattr(result, "value", result), "ns")
+ elif isinstance(result, pd.Timestamp):
+ # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668
+ # numpy fails to convert pd.Timestamp to np.datetime64[ns]
+ result = np.asarray(result.to_datetime64())
+ elif self.dtype != object:
+ result = np.asarray(result, dtype=self.dtype)
+
+ # as for numpy.ndarray indexing, we always want the result to be
+ # a NumPy array.
+ result = utils.to_0d_array(result)
+
+ return result
+
+ def transpose(self, order) -> pd.Index:
+ return self.array # self.array should always be one-dimensional
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"
+
+ def copy(self, deep: bool = True) -> "PandasIndexingAdapter":
+ # Not the same as just writing `self.array.copy(deep=deep)`, as
+ # shallow copies of the underlying numpy.ndarrays become deep ones
+ # upon pickling
+ # >>> len(pickle.dumps((self.array, self.array)))
+ # 4000281
+ # >>> len(pickle.dumps((self.array, self.array.copy(deep=False))))
+ # 8000341
+ array = self.array.copy(deep=True) if deep else self.array
+ return type(self)(array, self._dtype)
+
+
+class PandasMultiIndexingAdapter(PandasIndexingAdapter):
+ """Handles explicit indexing for a pandas.MultiIndex.
+
+ This allows creating one instance for each multi-index level while
+ preserving indexing efficiency (results are memoized and indexing may be
+ delegated to another adapter instance wrapping the same multi-index).
+
+ """
+
+ __slots__ = ("array", "_dtype", "level", "adapter")
+
+ def __init__(
+ self,
+ array: pd.MultiIndex,
+ dtype: DTypeLike = None,
+ level: Optional[str] = None,
+ adapter: Optional[PandasIndexingAdapter] = None,
+ ):
+ super().__init__(array, dtype)
+ self.level = level
+ self.adapter = adapter
+
+ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
+ if self.level is not None:
+ return self.array.get_level_values(self.level).values
+ else:
+ return super().__array__(dtype)
+
+ @functools.lru_cache(1)
+ def __getitem__(self, indexer):
+ if self.adapter is None:
+ return super().__getitem__(indexer)
+ else:
+ return self.adapter.__getitem__(indexer)
+
+ def __repr__(self) -> str:
+ if self.level is None:
+ return super().__repr__()
+ else:
+ props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})"
+ return f"{type(self).__name__}{props}"
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index db5b95fd415..b8b32bdaa01 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -578,7 +578,7 @@ def merge_core(
combine_attrs: Optional[str] = "override",
priority_arg: Optional[int] = None,
explicit_coords: Optional[Sequence] = None,
- indexes: Optional[Mapping[Hashable, Index]] = None,
+ indexes: Optional[Mapping[Hashable, Any]] = None,
fill_value: object = dtypes.NA,
) -> _MergeResult:
"""Core logic for merging labeled objects.
@@ -601,7 +601,8 @@ def merge_core(
explicit_coords : set, optional
An explicit list of variables from `objects` that are coordinates.
indexes : dict, optional
- Dictionary with values given by pandas.Index objects.
+ Dictionary with values given by xarray.Index objects or anything that
+ may be cast to pandas.Index objects.
fill_value : scalar, optional
Value to use for newly missing values
@@ -979,8 +980,14 @@ def dataset_update_method(
other[key] = value.drop_vars(coord_names)
# use ds.coords and not ds.indexes, else str coords are cast to object
- # TODO: benbovy - flexible indexes: fix this (it only works with pandas indexes)
- indexes = {key: PandasIndex(dataset.coords[key]) for key in dataset.xindexes.keys()}
+ # TODO: benbovy - flexible indexes: make it work with any xarray index
+ indexes = {}
+ for key, index in dataset.xindexes.items():
+ if isinstance(index, PandasIndex):
+ indexes[key] = dataset.coords[key]
+ else:
+ indexes[key] = index
+
return merge_core(
[dataset, other],
priority_arg=1,
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index 6b5742104e4..36983a227b9 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -12,7 +12,7 @@
from .common import _contains_datetime_like_objects, ones_like
from .computation import apply_ufunc
from .duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
-from .options import _get_keep_attrs
+from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_version, is_duck_dask_array
from .utils import OrderedSet, is_scalar
from .variable import Variable, broadcast_variables
@@ -405,6 +405,12 @@ def _bfill(arr, n=None, axis=-1):
def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "ffill requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
axis = arr.get_axis_num(dim)
# work around for bottleneck 178
@@ -422,6 +428,12 @@ def ffill(arr, dim=None, limit=None):
def bfill(arr, dim=None, limit=None):
"""backfill missing values"""
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "bfill requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
axis = arr.get_axis_num(dim)
# work around for bottleneck 178
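For illustration, a minimal sketch of the new guard (values are arbitrary):

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, np.nan, 3.0], dims="x")
    with xr.set_options(use_bottleneck=False):
        da.ffill("x")   # raises RuntimeError per the check added above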
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index 3aaed08575a..3e0f550dd30 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -4,6 +4,8 @@
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
+from .options import OPTIONS
+
try:
import bottleneck as bn
@@ -138,6 +140,7 @@ def f(values, axis=None, **kwargs):
if (
_USE_BOTTLENECK
+ and OPTIONS["use_bottleneck"]
and isinstance(values, np.ndarray)
and bn_func is not None
and not isinstance(axis, tuple)
diff --git a/xarray/core/options.py b/xarray/core/options.py
index 7104e12c29f..1d916ff0f7c 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -14,6 +14,7 @@
FILE_CACHE_MAXSIZE = "file_cache_maxsize"
KEEP_ATTRS = "keep_attrs"
WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files"
+USE_BOTTLENECK = "use_bottleneck"
OPTIONS = {
@@ -31,6 +32,7 @@
FILE_CACHE_MAXSIZE: 128,
KEEP_ATTRS: "default",
WARN_FOR_UNCLOSED_FILES: False,
+ USE_BOTTLENECK: True,
}
_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
@@ -54,6 +56,7 @@ def _positive_integer(value):
FILE_CACHE_MAXSIZE: _positive_integer,
KEEP_ATTRS: lambda choice: choice in [True, False, "default"],
WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool),
+ USE_BOTTLENECK: lambda value: isinstance(value, bool),
}
@@ -122,6 +125,9 @@ class set_options:
attrs, ``False`` to always discard them, or ``'default'`` to use original
logic that attrs should only be kept in unambiguous circumstances.
Default: ``'default'``.
+ - ``use_bottleneck``: allow using bottleneck. Either ``True`` to accelerate
+ operations using bottleneck if it is installed or ``False`` to never use it.
+ Default: ``True``.
- ``display_style``: display style to use in jupyter for xarray objects.
Default: ``'html'``. Other options are ``'text'``.
- ``display_expand_attrs``: whether to expand the attributes section for
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 795d30af28f..2c7f4249b5e 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -27,8 +27,6 @@
import numpy as np
-from xarray.core.indexes import PandasIndex
-
from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset
@@ -295,9 +293,10 @@ def _wrapper(
# check that index lengths and values are as expected
for name, index in result.xindexes.items():
if name in expected["shapes"]:
- if len(index) != expected["shapes"][name]:
+ if result.sizes[name] != expected["shapes"][name]:
raise ValueError(
- f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}."
+ f"Received dimension {name!r} of length {result.sizes[name]}. "
+ f"Expected length {expected['shapes'][name]}."
)
if name in expected["indexes"]:
expected_index = expected["indexes"][name]
@@ -503,16 +502,10 @@ def subset_dataset_to_block(
}
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
- # TODO: benbovy - flexible indexes: clean this up
- # for now assumes pandas index (thus can be indexed) but it won't be the case for
- # all indexes
- expected_indexes = {}
- for dim in indexes:
- idx = indexes[dim].to_pandas_index()[
- _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
- ]
- expected_indexes[dim] = PandasIndex(idx)
- expected["indexes"] = expected_indexes
+ expected["indexes"] = {
+ dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
+ for dim in indexes
+ }
from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
@@ -557,7 +550,13 @@ def subset_dataset_to_block(
},
)
- result = Dataset(coords=indexes, attrs=template.attrs)
+ # TODO: benbovy - flexible indexes: make it work with custom indexes
+ # this will need to pass both indexes and coords to the Dataset constructor
+ result = Dataset(
+ coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
+ attrs=template.attrs,
+ )
+
for index in result.xindexes:
result[index].attrs = template[index].attrs
result[index].encoding = template[index].encoding
@@ -568,8 +567,8 @@ def subset_dataset_to_block(
for dim in dims:
if dim in output_chunks:
var_chunks.append(output_chunks[dim])
- elif dim in indexes:
- var_chunks.append((len(indexes[dim]),))
+ elif dim in result.xindexes:
+ var_chunks.append((result.sizes[dim],))
elif dim in template.dims:
# new unindexed dimension
var_chunks.append((template.sizes[dim],))
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 9f47da6c8cc..d1649235006 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -1,4 +1,5 @@
from distutils.version import LooseVersion
+from importlib import import_module
import numpy as np
@@ -6,42 +7,57 @@
integer_types = (int, np.integer)
-try:
- import dask
- import dask.array
- from dask.base import is_dask_collection
- dask_version = LooseVersion(dask.__version__)
+class DuckArrayModule:
+ """
+ Solely for internal isinstance and version checks.
- # solely for isinstance checks
- dask_array_type = (dask.array.Array,)
+ Motivated by the need to import pint only when required (as pint currently imports xarray)
+ https://github.com/pydata/xarray/pull/5561#discussion_r664815718
+ """
- def is_duck_dask_array(x):
- return is_duck_array(x) and is_dask_collection(x)
+ def __init__(self, mod):
+ try:
+ duck_array_module = import_module(mod)
+ duck_array_version = LooseVersion(duck_array_module.__version__)
+
+ if mod == "dask":
+ duck_array_type = (import_module("dask.array").Array,)
+ elif mod == "pint":
+ duck_array_type = (duck_array_module.Quantity,)
+ elif mod == "cupy":
+ duck_array_type = (duck_array_module.ndarray,)
+ elif mod == "sparse":
+ duck_array_type = (duck_array_module.SparseArray,)
+ else:
+ raise NotImplementedError
+
+ except ImportError: # pragma: no cover
+ duck_array_module = None
+ duck_array_version = LooseVersion("0.0.0")
+ duck_array_type = ()
+ self.module = duck_array_module
+ self.version = duck_array_version
+ self.type = duck_array_type
+ self.available = duck_array_module is not None
-except ImportError: # pragma: no cover
- dask_version = LooseVersion("0.0.0")
- dask_array_type = ()
- is_duck_dask_array = lambda _: False
- is_dask_collection = lambda _: False
-try:
- # solely for isinstance checks
- import sparse
+def is_duck_dask_array(x):
+ if DuckArrayModule("dask").available:
+ from dask.base import is_dask_collection
+
+ return is_duck_array(x) and is_dask_collection(x)
+ else:
+ return False
+
- sparse_version = LooseVersion(sparse.__version__)
- sparse_array_type = (sparse.SparseArray,)
-except ImportError: # pragma: no cover
- sparse_version = LooseVersion("0.0.0")
- sparse_array_type = ()
+dsk = DuckArrayModule("dask")
+dask_version = dsk.version
+dask_array_type = dsk.type
-try:
- # solely for isinstance checks
- import cupy
+sp = DuckArrayModule("sparse")
+sparse_array_type = sp.type
+sparse_version = sp.version
- cupy_version = LooseVersion(cupy.__version__)
- cupy_array_type = (cupy.ndarray,)
-except ImportError: # pragma: no cover
- cupy_version = LooseVersion("0.0.0")
- cupy_array_type = ()
+cupy_array_type = DuckArrayModule("cupy").type
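For illustration, a minimal sketch of how DuckArrayModule is used for lazy
isinstance/version checks (pint is only imported if it is actually installed):

    from xarray.core.pycompat import DuckArrayModule

    pint = DuckArrayModule("pint")
    if pint.available:
        print(pint.version)       # LooseVersion of the installed pint
    isinstance(1.0, pint.type)    # False, and safe even when pint is missing (type is ())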
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index b87dcda24b0..0cac9f2b129 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -7,7 +7,7 @@
from . import dtypes, duck_array_ops, utils
from .arithmetic import CoarsenArithmetic
-from .options import _get_keep_attrs
+from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .utils import either_dict_or_kwargs
@@ -48,10 +48,10 @@ class Rolling:
xarray.DataArray.rolling
"""
- __slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs")
- _attributes = ("window", "min_periods", "center", "dim", "keep_attrs")
+ __slots__ = ("obj", "window", "min_periods", "center", "dim")
+ _attributes = ("window", "min_periods", "center", "dim")
- def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object.
@@ -89,15 +89,6 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
self.min_periods = np.prod(self.window) if min_periods is None else min_periods
- if keep_attrs is not None:
- warnings.warn(
- "Passing ``keep_attrs`` to ``rolling`` is deprecated and will raise an"
- " error in xarray 0.18. Please pass ``keep_attrs`` directly to the"
- " applied function. Note that keep_attrs is now True per default.",
- FutureWarning,
- )
- self.keep_attrs = keep_attrs
-
def __repr__(self):
"""provide a nice str repr of our rolling object"""
@@ -188,15 +179,8 @@ def _mapping_to_list(
)
def _get_keep_attrs(self, keep_attrs):
-
if keep_attrs is None:
- # TODO: uncomment the next line and remove the others after the deprecation
- # keep_attrs = _get_keep_attrs(default=True)
-
- if self.keep_attrs is None:
- keep_attrs = _get_keep_attrs(default=True)
- else:
- keep_attrs = self.keep_attrs
+ keep_attrs = _get_keep_attrs(default=True)
return keep_attrs
@@ -204,7 +188,7 @@ def _get_keep_attrs(self, keep_attrs):
class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)
- def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
@@ -235,9 +219,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
xarray.Dataset.rolling
xarray.Dataset.groupby
"""
- super().__init__(
- obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
- )
+ super().__init__(obj, windows, min_periods=min_periods, center=center)
# TODO legacy attribute
self.window_labels = self.obj[self.dim[0]]
@@ -535,7 +517,8 @@ def _numpy_or_bottleneck_reduce(
del kwargs["dim"]
if (
- bottleneck_move_func is not None
+ OPTIONS["use_bottleneck"]
+ and bottleneck_move_func is not None
and not is_duck_dask_array(self.obj.data)
and len(self.dim) == 1
):
@@ -561,7 +544,7 @@ def _numpy_or_bottleneck_reduce(
class DatasetRolling(Rolling):
__slots__ = ("rollings",)
- def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
@@ -592,7 +575,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
xarray.Dataset.groupby
xarray.DataArray.groupby
"""
- super().__init__(obj, windows, min_periods, center, keep_attrs)
+ super().__init__(obj, windows, min_periods, center)
if any(d not in self.obj.dims for d in self.dim):
raise KeyError(self.dim)
# Keep each Rolling object as a dictionary
@@ -768,11 +751,10 @@ class Coarsen(CoarsenArithmetic):
"windows",
"side",
"trim_excess",
- "keep_attrs",
)
_attributes = ("windows", "side", "trim_excess")
- def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs):
+ def __init__(self, obj, windows, boundary, side, coord_func):
"""
Moving window object.
@@ -799,17 +781,6 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs):
self.side = side
self.boundary = boundary
- if keep_attrs is not None:
- warnings.warn(
- "Passing ``keep_attrs`` to ``coarsen`` is deprecated and will raise an"
- " error in xarray 0.19. Please pass ``keep_attrs`` directly to the"
- " applied function, i.e. use ``ds.coarsen(...).mean(keep_attrs=False)``"
- " instead of ``ds.coarsen(..., keep_attrs=False).mean()``"
- " Note that keep_attrs is now True per default.",
- FutureWarning,
- )
- self.keep_attrs = keep_attrs
-
absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims]
if absent_dims:
raise ValueError(
@@ -823,15 +794,8 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs):
self.coord_func = coord_func
def _get_keep_attrs(self, keep_attrs):
-
if keep_attrs is None:
- # TODO: uncomment the next line and remove the others after the deprecation
- # keep_attrs = _get_keep_attrs(default=True)
-
- if self.keep_attrs is None:
- keep_attrs = _get_keep_attrs(default=True)
- else:
- keep_attrs = self.keep_attrs
+ keep_attrs = _get_keep_attrs(default=True)
return keep_attrs
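For illustration, a minimal sketch of the behaviour once the deprecated keyword is
removed (values are arbitrary):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims="x", attrs={"units": "m"})

    # keep_attrs is now passed to the applied reduction, not to .rolling()/.coarsen()
    da.rolling(x=3).mean(keep_attrs=False)

    # with bottleneck disabled, rolling reductions fall back to the numpy code path
    with xr.set_options(use_bottleneck=False):
        da.rolling(x=3).mean()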
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 72e34932579..a139d2ef10a 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -10,6 +10,7 @@
import warnings
from enum import Enum
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -32,12 +33,6 @@
import numpy as np
import pandas as pd
-if sys.version_info >= (3, 10):
- from typing import TypeGuard
-else:
- from typing_extensions import TypeGuard
-
-
K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
@@ -297,11 +292,7 @@ def either_dict_or_kwargs(
return pos_kwargs
-def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]:
- """Whether to treat a value as a scalar.
-
- Any non-iterable, string, or 0-D array
- """
+def _is_scalar(value, include_0d):
from .variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES
if include_0d:
@@ -316,6 +307,37 @@ def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]:
)
+# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without
+# requiring typing_extensions as a required dependency to _run_ the code (it is required
+# to type-check).
+try:
+ if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+ else:
+ from typing_extensions import TypeGuard
+except ImportError:
+ if TYPE_CHECKING:
+ raise
+ else:
+
+ def is_scalar(value: Any, include_0d: bool = True) -> bool:
+ """Whether to treat a value as a scalar.
+
+ Any non-iterable, string, or 0-D array
+ """
+ return _is_scalar(value, include_0d)
+
+
+else:
+
+ def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]:
+ """Whether to treat a value as a scalar.
+
+ Any non-iterable, string, or 0-D array
+ """
+ return _is_scalar(value, include_0d)
+
+
def is_valid_numpy_dtype(dtype: Any) -> bool:
try:
np.dtype(dtype)
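For illustration, a minimal sketch of what the TypeGuard annotation buys at type-check
time (the runtime behaviour of is_scalar is unchanged; the helper below is hypothetical):

    from typing import Any, Hashable, Optional, Set

    from xarray.core.utils import is_scalar

    def as_key_set(value: Any) -> Optional[Set[Hashable]]:
        if is_scalar(value):
            # with typing_extensions available, type checkers narrow `value` to Hashable here
            return {value}
        return None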
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 48535f08958..47ec265d15e 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -25,14 +25,22 @@
from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
from .arithmetic import VariableArithmetic
from .common import AbstractArray
-from .indexes import PandasIndex, wrap_pandas_index
-from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable
-from .options import _get_keep_attrs
+from .indexes import PandasIndex, PandasMultiIndex
+from .indexing import (
+ BasicIndexer,
+ OuterIndexer,
+ PandasIndexingAdapter,
+ VectorizedIndexer,
+ as_indexable,
+)
+from .options import OPTIONS, _get_keep_attrs
from .pycompat import (
+ DuckArrayModule,
cupy_array_type,
dask_array_type,
integer_types,
is_duck_dask_array,
+ sparse_array_type,
)
from .utils import (
NdimSizeLenMixin,
@@ -116,14 +124,9 @@ def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]":
obj = obj.copy(deep=False)
elif isinstance(obj, tuple):
if isinstance(obj[1], DataArray):
- # TODO: change into TypeError
- warnings.warn(
- (
- "Using a DataArray object to construct a variable is"
- " ambiguous, please extract the data using the .data property."
- " This will raise a TypeError in 0.19.0."
- ),
- DeprecationWarning,
+ raise TypeError(
+ "Using a DataArray object to construct a variable is"
+ " ambiguous, please extract the data using the .data property."
)
try:
obj = Variable(*obj)
@@ -173,11 +176,11 @@ def _maybe_wrap_data(data):
Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure
they can be indexed properly.
- NumpyArrayAdapter, PandasIndex and LazilyIndexedArray should
+ NumpyArrayAdapter, PandasIndexingAdapter and LazilyIndexedArray should
all pass through unmodified.
"""
if isinstance(data, pd.Index):
- return wrap_pandas_index(data)
+ return PandasIndexingAdapter(data)
return data
@@ -259,7 +262,7 @@ def _as_array_or_item(data):
TODO: remove this (replace with np.asarray) once these issues are fixed
"""
- data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data)
+ data = np.asarray(data)
if data.ndim == 0:
if data.dtype.kind == "M":
data = np.datetime64(data, "ns")
@@ -334,7 +337,9 @@ def nbytes(self):
@property
def _in_memory(self):
- return isinstance(self._data, (np.ndarray, np.number, PandasIndex)) or (
+ return isinstance(
+ self._data, (np.ndarray, np.number, PandasIndexingAdapter)
+ ) or (
isinstance(self._data, indexing.MemoryCachedArray)
and isinstance(self._data.array, indexing.NumpyIndexingAdapter)
)
@@ -542,7 +547,14 @@ def to_index_variable(self):
def _to_xindex(self):
# temporary function used internally as a replacement of to_index()
# returns an xarray Index instance instead of a pd.Index instance
- return wrap_pandas_index(self.to_index())
+ index_var = self.to_index_variable()
+ index = index_var.to_index()
+ dim = index_var.dims[0]
+
+ if isinstance(index, pd.MultiIndex):
+ return PandasMultiIndex(index, dim)
+ else:
+ return PandasIndex(index, dim)
def to_index(self):
"""Convert this variable to a pandas.Index"""
@@ -1069,6 +1081,30 @@ def chunk(self, chunks={}, name=None, lock=False):
return self._replace(data=data)
+ def to_numpy(self) -> np.ndarray:
+ """Coerces wrapped data to numpy and returns a numpy.ndarray"""
+ # TODO an entrypoint so array libraries can choose coercion method?
+ data = self.data
+
+ # TODO first attempt to call .to_numpy() once some libraries implement it
+ if isinstance(data, dask_array_type):
+ data = data.compute()
+ if isinstance(data, cupy_array_type):
+ data = data.get()
+ # pint has to be imported dynamically as pint imports xarray
+ pint_array_type = DuckArrayModule("pint").type
+ if isinstance(data, pint_array_type):
+ data = data.magnitude
+ if isinstance(data, sparse_array_type):
+ data = data.todense()
+ data = np.asarray(data)
+
+ return data
+
+ def as_numpy(self: VariableType) -> VariableType:
+ """Coerces wrapped data into a numpy array, returning a Variable."""
+ return self._replace(data=self.to_numpy())
+
def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
"""
use sparse-array as backend.
@@ -1378,7 +1414,11 @@ def roll(self, shifts=None, **shifts_kwargs):
result = result._roll_one_dim(dim, count)
return result
- def transpose(self, *dims) -> "Variable":
+ def transpose(
+ self,
+ *dims,
+ missing_dims: str = "raise",
+ ) -> "Variable":
"""Return a new Variable object with transposed dimensions.
Parameters
@@ -1386,6 +1426,12 @@ def transpose(self, *dims) -> "Variable":
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
+ missing_dims : {"raise", "warn", "ignore"}, default: "raise"
+ What to do if dimensions that should be selected from are not present in the
+ Variable:
+ - "raise": raise an exception
+ - "warn": raise a warning, and ignore the missing dimensions
+ - "ignore": ignore the missing dimensions
Returns
-------
@@ -1404,7 +1450,9 @@ def transpose(self, *dims) -> "Variable":
"""
if len(dims) == 0:
dims = self.dims[::-1]
- dims = tuple(infix_dims(dims, self.dims))
+ else:
+ dims = tuple(infix_dims(dims, self.dims, missing_dims))
+
if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
@@ -2025,6 +2073,12 @@ def rank(self, dim, pct=False):
--------
Dataset.rank, DataArray.rank
"""
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "rank requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
import bottleneck as bn
data = self.data
@@ -2559,8 +2613,8 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
# Unlike in Variable, always eagerly load values into memory
- if not isinstance(self._data, PandasIndex):
- self._data = PandasIndex(self._data)
+ if not isinstance(self._data, PandasIndexingAdapter):
+ self._data = PandasIndexingAdapter(self._data)
def __dask_tokenize__(self):
from dask.base import normalize_token
@@ -2895,7 +2949,7 @@ def assert_unique_multiindex_level_names(variables):
level_names = defaultdict(list)
all_level_names = set()
for var_name, var in variables.items():
- if isinstance(var._data, PandasIndex):
+ if isinstance(var._data, PandasIndexingAdapter):
idx_level_names = var.to_index_variable().level_names
if idx_level_names is not None:
for n in idx_level_names:
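For illustration, a minimal sketch of the new Variable coercion helpers (shown on a
numpy-backed Variable; the dask/pint/cupy/sparse branches above apply when the
corresponding duck array is wrapped):

    import numpy as np
    import xarray as xr

    var = xr.Variable("x", np.arange(3))
    arr = var.to_numpy()    # always a plain numpy.ndarray
    var2 = var.as_numpy()   # a Variable of the same shape, backed by numpy data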
diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py
index da7d523d28f..e20b6568e79 100644
--- a/xarray/plot/plot.py
+++ b/xarray/plot/plot.py
@@ -7,17 +7,22 @@
Dataset.plot._____
"""
import functools
+from distutils.version import LooseVersion
import numpy as np
import pandas as pd
+from ..core.alignment import broadcast
from .facetgrid import _easy_facetgrid
from .utils import (
_add_colorbar,
+ _adjust_legend_subtitles,
_assert_valid_xy,
_ensure_plottable,
_infer_interval_breaks,
_infer_xy_labels,
+ _is_numeric,
+ _legend_add_subtitle,
_process_cmap_cbar_kwargs,
_rescale_imshow_rgb,
_resolve_intervals_1dplot,
@@ -26,8 +31,132 @@
get_axis,
import_matplotlib_pyplot,
label_from_attrs,
+ legend_elements,
)
+# copied from seaborn
+_MARKERSIZE_RANGE = np.array([18.0, 72.0])
+
+
+def _infer_scatter_metadata(darray, x, z, hue, hue_style, size):
+ def _determine_array(darray, name, array_style):
+ """Find and determine what type of array it is."""
+ array = darray[name]
+ array_is_numeric = _is_numeric(array.values)
+
+ if array_style is None:
+ array_style = "continuous" if array_is_numeric else "discrete"
+ elif array_style not in ["discrete", "continuous"]:
+ raise ValueError(
+ f"The style '{array_style}' is not valid, "
+ "valid options are None, 'discrete' or 'continuous'."
+ )
+
+ array_label = label_from_attrs(array)
+
+ return array, array_style, array_label
+
+ # Add nice looking labels:
+ out = dict(ylabel=label_from_attrs(darray))
+ out.update(
+ {
+ k: label_from_attrs(darray[v]) if v in darray.coords else None
+ for k, v in [("xlabel", x), ("zlabel", z)]
+ }
+ )
+
+ # Add styles and labels for the dataarrays:
+ for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]:
+ tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label"
+ if a:
+ out[tp], out[stl], out[lbl] = _determine_array(darray, a, style)
+ else:
+ out[tp], out[stl], out[lbl] = None, None, None
+
+ return out
+
+
+# copied from seaborn
+def _parse_size(data, norm, width):
+ """
+ Determine whether the data is numeric or categorical, then normalize it
+ to the given width range.
+
+ If the data is categorical, normalize it to numbers.
+ """
+ plt = import_matplotlib_pyplot()
+
+ if data is None:
+ return None
+
+ data = data.values.ravel()
+
+ if not _is_numeric(data):
+ # Data is categorical.
+ # Use pd.unique instead of np.unique because that keeps
+ # the order of the labels:
+ levels = pd.unique(data)
+ numbers = np.arange(1, 1 + len(levels))
+ else:
+ levels = numbers = np.sort(np.unique(data))
+
+ min_width, max_width = width
+ # width_range = min_width, max_width
+
+ if norm is None:
+ norm = plt.Normalize()
+ elif isinstance(norm, tuple):
+ norm = plt.Normalize(*norm)
+ elif not isinstance(norm, plt.Normalize):
+ err = "``size_norm`` must be None, tuple, or Normalize object."
+ raise ValueError(err)
+
+ norm.clip = True
+ if not norm.scaled():
+ norm(np.asarray(numbers))
+ # limits = norm.vmin, norm.vmax
+
+ scl = norm(numbers)
+ widths = np.asarray(min_width + scl * (max_width - min_width))
+ if scl.mask.any():
+ widths[scl.mask] = 0
+ sizes = dict(zip(levels, widths))
+
+ return pd.Series(sizes)
+
+
+def _infer_scatter_data(
+ darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10)
+):
+ # Broadcast together all the chosen variables:
+ to_broadcast = dict(y=darray)
+ to_broadcast.update(
+ {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None}
+ )
+ to_broadcast.update(
+ {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims}
+ )
+ broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values()))))
+
+ # Normalize hue and size and create lookup tables:
+ for type_, mapping, norm, width in [
+ ("hue", None, None, [0, 1]),
+ ("size", size_mapping, size_norm, size_range),
+ ]:
+ broadcasted_type = broadcasted.get(type_, None)
+ if broadcasted_type is not None:
+ if mapping is None:
+ mapping = _parse_size(broadcasted_type, norm, width)
+
+ broadcasted[type_] = broadcasted_type.copy(
+ data=np.reshape(
+ mapping.loc[broadcasted_type.values.ravel()].values,
+ broadcasted_type.shape,
+ )
+ )
+ broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping)
+
+ return broadcasted
+
def _infer_line_data(darray, x, y, hue):
@@ -301,7 +430,7 @@ def line(
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
- xplt.values, yplt.values, kwargs
+ xplt.to_numpy(), yplt.to_numpy(), kwargs
)
xlabel = label_from_attrs(xplt, extra=x_suffix)
ylabel = label_from_attrs(yplt, extra=y_suffix)
@@ -320,7 +449,7 @@ def line(
ax.set_title(darray._title_for_slice())
if darray.ndim == 2 and add_legend:
- ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label)
+ ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
@@ -422,7 +551,7 @@ def hist(
"""
ax = get_axis(figsize, size, aspect, ax)
- no_nan = np.ravel(darray.values)
+ no_nan = np.ravel(darray.to_numpy())
no_nan = no_nan[pd.notnull(no_nan)]
primitive = ax.hist(no_nan, **kwargs)
@@ -435,6 +564,291 @@ def hist(
return primitive
+def scatter(
+ darray,
+ *args,
+ row=None,
+ col=None,
+ figsize=None,
+ aspect=None,
+ size=None,
+ ax=None,
+ hue=None,
+ hue_style=None,
+ x=None,
+ z=None,
+ xincrease=None,
+ yincrease=None,
+ xscale=None,
+ yscale=None,
+ xticks=None,
+ yticks=None,
+ xlim=None,
+ ylim=None,
+ add_legend=None,
+ add_colorbar=None,
+ cbar_kwargs=None,
+ cbar_ax=None,
+ vmin=None,
+ vmax=None,
+ norm=None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust=None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ _labels=True,
+ **kwargs,
+):
+ """
+ Scatter plot a DataArray along some coordinates.
+
+ Parameters
+ ----------
+ darray : DataArray
+ DataArray to plot.
+ x, z : str, optional
+ Variable names for the x and z axes; the DataArray itself supplies the y values.
+ hue : str, optional
+ Variable by which to color scattered points.
+ hue_style : str, optional
+ Can be either 'discrete' (legend) or 'continuous' (color bar).
+ markersize : str, optional
+ Variable by which to vary the size of scattered points.
+ size_norm : optional
+ Either None or a ``Normalize`` instance to normalize the ``markersize`` variable.
+ add_guide : bool, optional
+ Add a guide that depends on ``hue_style``:
+ - for "discrete", build a legend.
+ This is the default for non-numeric `hue` variables.
+ - for "continuous", build a colorbar
+ row : str, optional
+ If passed, make row faceted plots on this dimension name
+ col : str, optional
+ If passed, make column faceted plots on this dimension name
+ col_wrap : int, optional
+ Use together with ``col`` to wrap faceted plots
+ ax : matplotlib axes object, optional
+ If None, uses the current axis. Not applicable when using facets.
+ subplot_kws : dict, optional
+ Dictionary of keyword arguments for matplotlib subplots. Only applies
+ to FacetGrid plotting.
+ aspect : scalar, optional
+ Aspect ratio of plot, so that ``aspect * size`` gives the width in
+ inches. Only used if a ``size`` is provided.
+ size : scalar, optional
+ If provided, create a new figure for the plot with the given size:
+ height (in inches) of each plot. See also: ``aspect``.
+ norm : ``matplotlib.colors.Normalize`` instance, optional
+ If the ``norm`` has vmin or vmax specified, the corresponding kwarg
+ must be None.
+ vmin, vmax : float, optional
+ Values to anchor the colormap, otherwise they are inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting one of these values will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ cmap : str or colormap, optional
+ The mapping from data values to color space. Either a
+ matplotlib colormap name or object. If not provided, this will
+ be either ``viridis`` (if the function infers a sequential
+ dataset) or ``RdBu_r`` (if the function infers a diverging
+ dataset). When `Seaborn` is installed, ``cmap`` may also be a
+ `seaborn` color palette. If ``cmap`` is seaborn color palette
+ and the plot type is not ``contour`` or ``contourf``, ``levels``
+ must also be specified.
+ colors : color-like or list of color-like, optional
+ A single color or a list of colors. If the plot type is not ``contour``
+ or ``contourf``, the ``levels`` argument is required.
+ center : float, optional
+ The value at which to center the colormap. Passing this value implies
+ use of a diverging colormap. Setting it to ``False`` prevents use of a
+ diverging colormap.
+ robust : bool, optional
+ If True and ``vmin`` or ``vmax`` are absent, the colormap range is
+ computed with 2nd and 98th percentiles instead of the extreme values.
+ extend : {"neither", "both", "min", "max"}, optional
+ How to draw arrows extending the colorbar beyond its limits. If not
+ provided, extend is inferred from vmin, vmax and the data limits.
+ levels : int or list-like object, optional
+ Split the colormap (cmap) into discrete color intervals. If an integer
+ is provided, "nice" levels are chosen based on the data range: this can
+ imply that the final number of levels is not exactly the expected one.
+ Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
+ setting ``levels=np.linspace(vmin, vmax, N)``.
+ **kwargs : optional
+ Additional keyword arguments to matplotlib
+ """
+ plt = import_matplotlib_pyplot()
+
+ # Handle facetgrids first
+ if row or col:
+ allargs = locals().copy()
+ allargs.update(allargs.pop("kwargs"))
+ allargs.pop("darray")
+ subplot_kws = dict(projection="3d") if z is not None else None
+ return _easy_facetgrid(
+ darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs
+ )
+
+ # Further
+ _is_facetgrid = kwargs.pop("_is_facetgrid", False)
+ if _is_facetgrid:
+ # Why do I need to pop these here?
+ kwargs.pop("y", None)
+ kwargs.pop("args", None)
+ kwargs.pop("add_labels", None)
+
+ _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None))
+ size_norm = kwargs.pop("size_norm", None)
+ size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid
+ cmap_params = kwargs.pop("cmap_params", None)
+
+ figsize = kwargs.pop("figsize", None)
+ subplot_kws = dict()
+ if z is not None and ax is None:
+ # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2.
+ # Remove when minimum requirement of matplotlib is 3.2:
+ from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa
+
+ subplot_kws.update(projection="3d")
+ ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
+ # Using 30, 30 minimizes rotation of the plot, making it easier to
+ # build on your intuition from 2D plots:
+ if LooseVersion(plt.matplotlib.__version__) < "3.5.0":
+ ax.view_init(azim=30, elev=30)
+ else:
+ # https://github.com/matplotlib/matplotlib/pull/19873
+ ax.view_init(azim=30, elev=30, vertical_axis="y")
+ else:
+ ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
+
+ _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes)
+
+ add_guide = kwargs.pop("add_guide", None)
+ if add_legend is not None:
+ pass
+ elif add_guide is None or add_guide is True:
+ add_legend = True if _data["hue_style"] == "discrete" else False
+ else:
+ add_legend = False
+
+ if add_colorbar is not None:
+ pass
+ elif add_guide is None or add_guide is True:
+ add_colorbar = True if _data["hue_style"] == "continuous" else False
+ else:
+ add_colorbar = False
+
+ # need to infer size_mapping with full dataset
+ _data.update(
+ _infer_scatter_data(
+ darray,
+ x,
+ z,
+ hue,
+ _sizes,
+ size_norm,
+ size_mapping,
+ _MARKERSIZE_RANGE,
+ )
+ )
+
+ cmap_params_subset = {}
+ if _data["hue"] is not None:
+ kwargs.update(c=_data["hue"].values.ravel())
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ scatter, _data["hue"].values, **locals()
+ )
+
+ # subset that can be passed to scatter, hist2d
+ cmap_params_subset = {
+ vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"]
+ }
+
+ if _data["size"] is not None:
+ kwargs.update(s=_data["size"].values.ravel())
+
+ if LooseVersion(plt.matplotlib.__version__) < "3.5.0":
+ # Plot the data. 3d plots have the z value in the upward direction
+ # instead of y. To make jumping between 2d and 3d easy and intuitive
+ # switch the order so that z is shown in the depthwise direction:
+ axis_order = ["x", "z", "y"]
+ else:
+ # Switching axis order not needed in 3.5.0, can also simplify the code
+ # that uses axis_order:
+ # https://github.com/matplotlib/matplotlib/pull/19873
+ axis_order = ["x", "y", "z"]
+
+ primitive = ax.scatter(
+ *[
+ _data[v].values.ravel()
+ for v in axis_order
+ if _data.get(v, None) is not None
+ ],
+ **cmap_params_subset,
+ **kwargs,
+ )
+
+ # Set x, y, z labels:
+ i = 0
+ set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)]
+ for v in axis_order:
+ if _data.get(f"{v}label", None) is not None:
+ set_label[i](_data[f"{v}label"])
+ i += 1
+
+ if add_legend:
+
+ def to_label(data, key, x):
+ """Map prop values back to its original values."""
+ if key in data:
+ # Use reindex to be less sensitive to float errors.
+ # Return as numpy array since legend_elements
+ # seems to require that:
+ return data[key].reindex(x, method="nearest").to_numpy()
+ else:
+ return x
+
+ handles, labels = [], []
+ for subtitle, prop, func in [
+ (
+ _data["hue_label"],
+ "colors",
+ functools.partial(to_label, _data, "hue_to_label"),
+ ),
+ (
+ _data["size_label"],
+ "sizes",
+ functools.partial(to_label, _data, "size_to_label"),
+ ),
+ ]:
+ if subtitle:
+ # Get legend handles and labels that displays the
+ # values correctly. Order might be different because
+ # legend_elements uses np.unique instead of pd.unique,
+ # FacetGrid.add_legend might have troubles with this:
+ hdl, lbl = legend_elements(primitive, prop, num="auto", func=func)
+ hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter)
+ handles += hdl
+ labels += lbl
+ legend = ax.legend(handles, labels, framealpha=0.5)
+ _adjust_legend_subtitles(legend)
+
+ if add_colorbar and _data["hue_label"]:
+ if _data["hue_style"] == "discrete":
+ raise NotImplementedError("Cannot create a colorbar for non numerics.")
+ cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
+ if "label" not in cbar_kwargs:
+ cbar_kwargs["label"] = _data["hue_label"]
+ _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)
+
+ return primitive
+
+
# MUST run before any 2d plotting functions are defined since
# _plot2d decorator adds them as methods here.
class _PlotMethods:
@@ -468,6 +882,10 @@ def line(self, *args, **kwargs):
def step(self, *args, **kwargs):
return step(self._da, *args, **kwargs)
+ @functools.wraps(scatter)
+ def _scatter(self, *args, **kwargs):
+ return scatter(self._da, *args, **kwargs)
+
def override_signature(f):
def wrapper(func):
@@ -735,8 +1153,8 @@ def newplotfunc(
dims = (yval.dims[0], xval.dims[0])
# better to pass the ndarrays directly to plotting functions
- xval = xval.values
- yval = yval.values
+ xval = xval.to_numpy()
+ yval = yval.to_numpy()
# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
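For illustration, a minimal sketch of the new DataArray scatter entry point (requires
matplotlib; the module-level function is called directly here because the plot-accessor
method is added as the private ``_scatter`` above; values are arbitrary):

    import numpy as np
    import xarray as xr
    from xarray.plot.plot import scatter

    da = xr.DataArray(
        np.random.randn(4, 3),
        dims=("x", "y"),
        coords={"x": [1, 2, 3, 4], "y": [10, 20, 30]},
    )
    scatter(da, x="x")   # x values from the coordinate, y values from the DataArray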
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index 85f9c8c5a86..f2f296096a5 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -9,6 +9,7 @@
import pandas as pd
from ..core.options import OPTIONS
+from ..core.pycompat import DuckArrayModule
from ..core.utils import is_scalar
try:
@@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""):
else:
name = ""
- if da.attrs.get("units"):
- units = " [{}]".format(da.attrs["units"])
- elif da.attrs.get("unit"):
- units = " [{}]".format(da.attrs["unit"])
+ def _get_units_from_attrs(da):
+ if da.attrs.get("units"):
+ units = " [{}]".format(da.attrs["units"])
+ elif da.attrs.get("unit"):
+ units = " [{}]".format(da.attrs["unit"])
+ else:
+ units = ""
+ return units
+
+ pint_array_type = DuckArrayModule("pint").type
+ if isinstance(da.data, pint_array_type):
+ units = " [{}]".format(str(da.data.units))
else:
- units = ""
+ units = _get_units_from_attrs(da)
return "\n".join(textwrap.wrap(name + extra + units, 30))
@@ -896,6 +905,234 @@ def _get_nice_quiver_magnitude(u, v):
import matplotlib as mpl
ticker = mpl.ticker.MaxNLocator(3)
- mean = np.mean(np.hypot(u.values, v.values))
+ mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude
+
+
+# Copied from matplotlib, tweaked so func can return strings.
+# https://github.com/matplotlib/matplotlib/issues/19555
+def legend_elements(
+ self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs
+):
+ """
+ Create legend handles and labels for a PathCollection.
+
+ Each legend handle is a `.Line2D` representing the Path that was drawn,
+ and each label is a string describing what each Path represents.
+
+ This is useful for obtaining a legend for a `~.Axes.scatter` plot;
+ e.g.::
+
+ scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3])
+ plt.legend(*scatter.legend_elements())
+
+ creates three legend elements, one for each color with the numerical
+ values passed to *c* as the labels.
+
+ Also see the :ref:`automatedlegendcreation` example.
+
+
+ Parameters
+ ----------
+ prop : {"colors", "sizes"}, default: "colors"
+ If "colors", the legend handles will show the different colors of
+ the collection. If "sizes", the legend will show the different
+ sizes. To set both, use *kwargs* to directly edit the `.Line2D`
+ properties.
+ num : int, None, "auto" (default), array-like, or `~.ticker.Locator`
+ Target number of elements to create.
+ If None, use all unique elements of the mappable array. If an
+ integer, target to use *num* elements in the normed range.
+ If *"auto"*, try to determine which option better suits the nature
+ of the data.
+ The number of created elements may slightly deviate from *num* due
+ to a `~.ticker.Locator` being used to find useful locations.
+ If a list or array, use exactly those elements for the legend.
+ Finally, a `~.ticker.Locator` can be provided.
+ fmt : str, `~matplotlib.ticker.Formatter`, or None (default)
+ The format or formatter to use for the labels. If a string, it must be
+ a valid input for a `~.StrMethodFormatter`. If None (the default),
+ use a `~.ScalarFormatter`.
+ func : function, default: ``lambda x: x``
+ Function to calculate the labels. Often the size (or color)
+ argument to `~.Axes.scatter` will have been pre-processed by the
+ user using a function ``s = f(x)`` to make the markers visible;
+ e.g. ``size = np.log10(x)``. Providing the inverse of this
+ function here allows that pre-processing to be inverted, so that
+ the legend labels have the correct values; e.g. ``func = lambda
+ x: 10**x``.
+ **kwargs
+ Allowed keyword arguments are *color* and *size*. E.g. it may be
+ useful to set the color of the markers if *prop="sizes"* is used;
+ similarly to set the size of the markers if *prop="colors"* is
+ used. Any further parameters are passed onto the `.Line2D`
+ instance. This may be useful to e.g. specify a different
+ *markeredgecolor* or *alpha* for the legend handles.
+
+ Returns
+ -------
+ handles : list of `.Line2D`
+ Visual representation of each element of the legend.
+ labels : list of str
+ The string labels for elements of the legend.
+ """
+ import warnings
+
+ import matplotlib as mpl
+
+ mlines = mpl.lines
+
+ handles = []
+ labels = []
+
+ if prop == "colors":
+ arr = self.get_array()
+ if arr is None:
+ warnings.warn(
+ "Collection without array used. Make sure to "
+ "specify the values to be colormapped via the "
+ "`c` argument."
+ )
+ return handles, labels
+ _size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
+
+ def _get_color_and_size(value):
+ return self.cmap(self.norm(value)), _size
+
+ elif prop == "sizes":
+ arr = self.get_sizes()
+ _color = kwargs.pop("color", "k")
+
+ def _get_color_and_size(value):
+ return _color, np.sqrt(value)
+
+ else:
+ raise ValueError(
+ "Valid values for `prop` are 'colors' or "
+ f"'sizes'. You supplied '{prop}' instead."
+ )
+
+ # Get the unique values and their labels:
+ values = np.unique(arr)
+ label_values = np.asarray(func(values))
+ label_values_are_numeric = np.issubdtype(label_values.dtype, np.number)
+
+ # Handle the label format:
+ if fmt is None and label_values_are_numeric:
+ fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
+ elif fmt is None and not label_values_are_numeric:
+ fmt = mpl.ticker.StrMethodFormatter("{x}")
+ elif isinstance(fmt, str):
+ fmt = mpl.ticker.StrMethodFormatter(fmt)
+ fmt.create_dummy_axis()
+
+ if num == "auto":
+ num = 9
+ if len(values) <= num:
+ num = None
+
+ if label_values_are_numeric:
+ label_values_min = label_values.min()
+ label_values_max = label_values.max()
+ fmt.set_bounds(label_values_min, label_values_max)
+
+ if num is not None:
+ # Labels are numerical but larger than the target
+ # number of elements, reduce to target using matplotlibs
+ # ticker classes:
+ if isinstance(num, mpl.ticker.Locator):
+ loc = num
+ elif np.iterable(num):
+ loc = mpl.ticker.FixedLocator(num)
+ else:
+ num = int(num)
+ loc = mpl.ticker.MaxNLocator(
+ nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10]
+ )
+
+ # Get nicely spaced label_values:
+ label_values = loc.tick_values(label_values_min, label_values_max)
+
+ # Remove extrapolated label_values:
+ cond = (label_values >= label_values_min) & (
+ label_values <= label_values_max
+ )
+ label_values = label_values[cond]
+
+ # Get the corresponding values by creating a linear interpolant
+ # with small step size:
+ values_interp = np.linspace(values.min(), values.max(), 256)
+ label_values_interp = func(values_interp)
+ ix = np.argsort(label_values_interp)
+ values = np.interp(label_values, label_values_interp[ix], values_interp[ix])
+ elif num is not None and not label_values_are_numeric:
+ # Labels are not numerical so modifying label_values is not
+ # possible, instead filter the array with nicely distributed
+ # indexes:
+ if type(num) == int:
+ loc = mpl.ticker.LinearLocator(num)
+ else:
+ raise ValueError("`num` only supports integers for non-numeric labels.")
+
+ ind = loc.tick_values(0, len(label_values) - 1).astype(int)
+ label_values = label_values[ind]
+ values = values[ind]
+
+ # Some formatters requires set_locs:
+ if hasattr(fmt, "set_locs"):
+ fmt.set_locs(label_values)
+
+ # Default settings for handles, add or override with kwargs:
+ kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha())
+ kw.update(kwargs)
+
+ for val, lab in zip(values, label_values):
+ color, size = _get_color_and_size(val)
+ h = mlines.Line2D(
+ [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw
+ )
+ handles.append(h)
+ labels.append(fmt(lab))
+
+ return handles, labels
+
+
+def _legend_add_subtitle(handles, labels, text, func):
+ """Add a subtitle to legend handles."""
+ if text and len(handles) > 1:
+ # Create a blank handle that's not visible, the
+ # invisibility will be used to discern which are subtitles
+ # or not:
+ blank_handle = func([], [], label=text)
+ blank_handle.set_visible(False)
+
+ # Subtitles are shown first:
+ handles = [blank_handle] + handles
+ labels = [text] + labels
+
+ return handles, labels
+
+
+def _adjust_legend_subtitles(legend):
+ """Make invisible-handle "subtitles" entries look more like titles."""
+ plt = import_matplotlib_pyplot()
+
+ # Legend title not in rcParams until 3.0
+ font_size = plt.rcParams.get("legend.title_fontsize", None)
+ hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children()
+ for hpack in hpackers:
+ draw_area, text_area = hpack.get_children()
+ handles = draw_area.get_children()
+
+ # Assume that all artists that are not visible are
+ # subtitles:
+ if not all(artist.get_visible() for artist in handles):
+ # Remove the dummy marker which will bring the text
+ # more to the center:
+ draw_area.set_width(0)
+ for text in text_area.get_children():
+ if font_size is not None:
+ # The subtitles should have the same font size
+ # as normal legend titles:
+ text.set_size(font_size)
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 9029dc1c621..d757fb451cc 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -83,6 +83,7 @@ def LooseVersion(vstring):
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
+has_cupy, requires_cupy = _importorskip("cupy")
has_cartopy, requires_cartopy = _importorskip("cartopy")
# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays
has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15")
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 5079cd390f1..3bbc2c93b31 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -42,7 +42,7 @@
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.coding.variables import SerializationWarning
from xarray.conventions import encode_dataset_coordinates
-from xarray.core import indexes, indexing
+from xarray.core import indexing
from xarray.core.options import set_options
from xarray.core.pycompat import dask_array_type
from xarray.tests import LooseVersion, mock
@@ -87,12 +87,8 @@
try:
import dask
import dask.array as da
-
- dask_version = dask.__version__
except ImportError:
- # needed for xfailed tests when dask < 2.4.0
- # remove when min dask > 2.4.0
- dask_version = "10.0"
+ pass
ON_WINDOWS = sys.platform == "win32"
default_value = object()
@@ -742,7 +738,7 @@ def find_and_validate_array(obj):
elif isinstance(obj.array, dask_array_type):
assert isinstance(obj, indexing.DaskIndexingAdapter)
elif isinstance(obj.array, pd.Index):
- assert isinstance(obj, indexes.PandasIndex)
+ assert isinstance(obj, indexing.PandasIndexingAdapter)
else:
raise TypeError(
"{} is wrapped by {}".format(type(obj.array), type(obj))
@@ -1961,7 +1957,6 @@ def test_hidden_zarr_keys(self):
with xr.decode_cf(store):
pass
- @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
@pytest.mark.parametrize("group", [None, "group1"])
def test_write_persistence_modes(self, group):
original = create_test_data()
@@ -2039,7 +2034,6 @@ def test_encoding_kwarg_fixed_width_string(self):
def test_dataset_caching(self):
super().test_dataset_caching()
- @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
def test_append_write(self):
super().test_append_write()
@@ -2122,7 +2116,6 @@ def test_check_encoding_is_consistent_after_append(self):
xr.concat([ds, ds_to_append], dim="time"),
)
- @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334")
def test_append_with_new_variable(self):
ds, ds_to_append, ds_with_new_var = create_append_test_data()
@@ -2777,6 +2770,7 @@ def test_dump_encodings_h5py(self):
@requires_h5netcdf
+@requires_netCDF4
class TestH5NetCDFAlreadyOpen:
def test_open_dataset_group(self):
import h5netcdf
@@ -2861,6 +2855,7 @@ def test_open_twice(self):
with open_dataset(f, engine="h5netcdf"):
pass
+ @requires_scipy
def test_open_fileobj(self):
# open in-memory datasets instead of local file paths
expected = create_test_data().drop_vars("dim3")
@@ -5162,11 +5157,12 @@ def test_open_fsspec():
@requires_h5netcdf
+@requires_netCDF4
def test_load_single_value_h5netcdf(tmp_path):
"""Test that numeric single-element vector attributes are handled fine.
At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable
- attributes as arrays of length 1, as oppesed to scalars for the NetCDF4
+ attributes as arrays of length 1, as opposed to scalars for the NetCDF4
backend. This was leading to a ValueError upon loading a single value from
a file, see #4471. Test that loading causes no failure.
"""
diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py
index 503c742252a..278a961166f 100644
--- a/xarray/tests/test_coarsen.py
+++ b/xarray/tests/test_coarsen.py
@@ -153,39 +153,6 @@ def test_coarsen_keep_attrs(funcname, argument):
assert result.da_not_coarsend.name == "da_not_coarsend"
-def test_coarsen_keep_attrs_deprecated():
- global_attrs = {"units": "test", "long_name": "testing"}
- attrs_da = {"da_attr": "test"}
-
- data = np.linspace(10, 15, 100)
- coords = np.linspace(1, 10, 100)
-
- ds = Dataset(
- data_vars={"da": ("coord", data)},
- coords={"coord": coords},
- attrs=global_attrs,
- )
- ds.da.attrs = attrs_da
-
- # deprecated option
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated"
- ):
- result = ds.coarsen(dim={"coord": 5}, keep_attrs=False).mean()
-
- assert result.attrs == {}
- assert result.da.attrs == {}
-
- # the keep_attrs in the reduction function takes precedence
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated"
- ):
- result = ds.coarsen(dim={"coord": 5}, keep_attrs=True).mean(keep_attrs=False)
-
- assert result.attrs == {}
- assert result.da.attrs == {}
-
-
@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@@ -267,31 +234,6 @@ def test_coarsen_da_keep_attrs(funcname, argument):
assert result.name == "name"
-def test_coarsen_da_keep_attrs_deprecated():
- attrs_da = {"da_attr": "test"}
-
- data = np.linspace(10, 15, 100)
- coords = np.linspace(1, 10, 100)
-
- da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da)
-
- # deprecated option
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated"
- ):
- result = da.coarsen(dim={"coord": 5}, keep_attrs=False).mean()
-
- assert result.attrs == {}
-
- # the keep_attrs in the reduction function takes precedence
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated"
- ):
- result = da.coarsen(dim={"coord": 5}, keep_attrs=True).mean(keep_attrs=False)
-
- assert result.attrs == {}
-
-
@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 09bed72496b..2439ea30b4b 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1,7 +1,6 @@
import functools
import operator
import pickle
-from distutils.version import LooseVersion
import numpy as np
import pandas as pd
@@ -21,6 +20,7 @@
result_name,
unified_dim_sizes,
)
+from xarray.core.pycompat import dask_version
from . import has_dask, raise_if_dask_computes, requires_dask
@@ -1307,7 +1307,7 @@ def test_vectorize_dask_dtype_without_output_dtypes(data_array):
@pytest.mark.skipif(
- LooseVersion(dask.__version__) > "2021.06",
+ dask_version > "2021.06",
reason="dask/dask#7669: can no longer pass output_dtypes and meta",
)
@requires_dask
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index f790587efa9..d5d460056aa 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -2,7 +2,6 @@
import pickle
import sys
from contextlib import suppress
-from distutils.version import LooseVersion
from textwrap import dedent
import numpy as np
@@ -13,6 +12,7 @@
import xarray.ufuncs as xu
from xarray import DataArray, Dataset, Variable
from xarray.core import duck_array_ops
+from xarray.core.pycompat import dask_version
from xarray.testing import assert_chunks_equal
from xarray.tests import mock
@@ -111,10 +111,7 @@ def test_indexing(self):
self.assertLazyAndIdentical(u[:1], v[:1])
self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]])
- @pytest.mark.skipif(
- LooseVersion(dask.__version__) < LooseVersion("2021.04.1"),
- reason="Requires dask v2021.04.1 or later",
- )
+ @pytest.mark.skipif(dask_version < "2021.04.1", reason="Requires dask >= 2021.04.1")
@pytest.mark.parametrize(
"expected_data, index",
[
@@ -133,10 +130,7 @@ def test_setitem_dask_array(self, expected_data, index):
arr[index] = 99
assert_identical(arr, expected)
- @pytest.mark.skipif(
- LooseVersion(dask.__version__) >= LooseVersion("2021.04.1"),
- reason="Requires dask v2021.04.0 or earlier",
- )
+ @pytest.mark.skipif(dask_version >= "2021.04.1", reason="Requires dask < 2021.04.1")
def test_setitem_dask_array_error(self):
with pytest.raises(TypeError, match=r"stored in a dask array"):
v = self.lazy_var
@@ -612,25 +606,6 @@ def test_dot(self):
lazy = self.lazy_array.dot(self.lazy_array[0])
self.assertLazyAndAllClose(eager, lazy)
- @pytest.mark.skipif(LooseVersion(dask.__version__) >= "2.0", reason="no meta")
- def test_dataarray_repr_legacy(self):
- data = build_dask_array("data")
- nonindex_coord = build_dask_array("coord")
- a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)})
- expected = dedent(
- """\
-
- {!r}
- Coordinates:
- y (x) int64 dask.array
- Dimensions without coordinates: x""".format(
- data
- )
- )
- assert expected == repr(a)
- assert kernel_call_count == 0 # should not evaluate dask array
-
- @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta")
def test_dataarray_repr(self):
data = build_dask_array("data")
nonindex_coord = build_dask_array("coord")
@@ -648,7 +623,6 @@ def test_dataarray_repr(self):
assert expected == repr(a)
assert kernel_call_count == 0 # should not evaluate dask array
- @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta")
def test_dataset_repr(self):
data = build_dask_array("data")
nonindex_coord = build_dask_array("coord")
@@ -1619,7 +1593,7 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy)
assert_equal(map_ds.map(lambda x: x), map_ds)
assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds)
- assert_equal(map_ds.update({"a": map_ds.a}), map_ds)
+ assert_equal(map_ds.assign({"a": map_ds.a}), map_ds)
# fails because of index error
# assert_equal(
@@ -1645,7 +1619,7 @@ def test_optimize():
# The graph_manipulation module is in dask since 2021.2 but it became usable with
# xarray only since 2021.3
-@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module")
+@pytest.mark.skipif(dask_version <= "2021.02.0", reason="new module")
def test_graph_manipulation():
"""dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder
function returned by __dask_postpersist__; also, the dsk passed to the rebuilder is
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index b9f04085935..8ab8bc872da 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -36,10 +36,12 @@
has_dask,
raise_if_dask_computes,
requires_bottleneck,
+ requires_cupy,
requires_dask,
requires_iris,
requires_numbagg,
requires_numexpr,
+ requires_pint_0_15,
requires_scipy,
requires_sparse,
source_ndarray,
@@ -148,7 +150,9 @@ def test_data_property(self):
def test_indexes(self):
array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])])
expected_indexes = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])}
- expected_xindexes = {k: PandasIndex(idx) for k, idx in expected_indexes.items()}
+ expected_xindexes = {
+ k: PandasIndex(idx, k) for k, idx in expected_indexes.items()
+ }
assert array.xindexes.keys() == expected_xindexes.keys()
assert array.indexes.keys() == expected_indexes.keys()
assert all([isinstance(idx, pd.Index) for idx in array.indexes.values()])
@@ -1471,7 +1475,7 @@ def test_coords_alignment(self):
def test_set_coords_update_index(self):
actual = DataArray([1, 2, 3], [("x", [1, 2, 3])])
actual.coords["x"] = ["a", "b", "c"]
- assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"]))
+ assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"]))
def test_coords_replacement_alignment(self):
# regression test for GH725
@@ -1635,15 +1639,6 @@ def test_init_value(self):
DataArray(np.array(1), coords=[("x", np.arange(10))])
def test_swap_dims(self):
- array = DataArray(np.random.randn(3), {"y": ("x", list("abc"))}, "x")
- expected = DataArray(array.values, {"y": list("abc")}, dims="y")
- actual = array.swap_dims({"x": "y"})
- assert_identical(expected, actual)
- for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()):
- pd.testing.assert_index_equal(
- expected.xindexes[dim_name].array, actual.xindexes[dim_name].array
- )
-
array = DataArray(np.random.randn(3), {"x": list("abc")}, "x")
expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y")
actual = array.swap_dims({"x": "y"})
@@ -6865,33 +6860,6 @@ def test_rolling_keep_attrs(funcname, argument):
assert result.name == "name"
-def test_rolling_keep_attrs_deprecated():
- attrs_da = {"da_attr": "test"}
-
- data = np.linspace(10, 15, 100)
- coords = np.linspace(1, 10, 100)
-
- da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da)
-
- # deprecated option
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated"
- ):
- result = da.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim")
-
- assert result.attrs == {}
-
- # the keep_attrs in the reduction function takes precedence
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated"
- ):
- result = da.rolling(dim={"coord": 5}, keep_attrs=True).construct(
- "window_dim", keep_attrs=False
- )
-
- assert result.attrs == {}
-
-
def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
xr.DataArray([1, 2, np.NaN]) > 0
@@ -7375,3 +7343,87 @@ def test_drop_duplicates(keep):
expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test")
result = ds.drop_duplicates("time", keep=keep)
assert_equal(expected, result)
+
+
+class TestNumpyCoercion:
+    # TODO: once the flexible indexes refactor is complete, also test coercion of dimension coords
+ def test_from_numpy(self):
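+        # a DataArray already backed by numpy should round-trip unchanged through as_numpy()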
+ da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
+
+ assert_identical(da.as_numpy(), da)
+ np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
+ np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))
+
+ @requires_dask
+ def test_from_dask(self):
+ da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
+ da_chunked = da.chunk(1)
+
+ assert_identical(da_chunked.as_numpy(), da.compute())
+ np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
+ np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))
+
+ @requires_pint_0_15
+ def test_from_pint(self):
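+        # as_numpy()/to_numpy() should strip the pint units and return plain numpy data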
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+ da = xr.DataArray(
+ Quantity(arr, units="Pa"),
+ dims="x",
+ coords={"lat": ("x", Quantity(arr + 3, units="m"))},
+ )
+
+ expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
+ assert_identical(da.as_numpy(), expected)
+ np.testing.assert_equal(da.to_numpy(), arr)
+ np.testing.assert_equal(da["lat"].to_numpy(), arr + 3)
+
+ @requires_sparse
+ def test_from_sparse(self):
+ import sparse
+
+ arr = np.diagflat([1, 2, 3])
+ sparr = sparse.COO.from_numpy(arr)
+ da = xr.DataArray(
+ sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)}
+ )
+
+ expected = xr.DataArray(
+ arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)}
+ )
+ assert_identical(da.as_numpy(), expected)
+ np.testing.assert_equal(da.to_numpy(), arr)
+
+ @requires_cupy
+ def test_from_cupy(self):
+ import cupy as cp
+
+ arr = np.array([1, 2, 3])
+ da = xr.DataArray(
+ cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))}
+ )
+
+ expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
+ assert_identical(da.as_numpy(), expected)
+ np.testing.assert_equal(da.to_numpy(), arr)
+
+ @requires_dask
+ @requires_pint_0_15
+ def test_from_pint_wrapping_dask(self):
+ import dask
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+ d = dask.array.from_array(arr)
+ da = xr.DataArray(
+ Quantity(d, units="Pa"),
+ dims="x",
+ coords={"lat": ("x", Quantity(d, units="m") * 2)},
+ )
+
+ result = da.as_numpy()
+ result.name = None # remove dask-assigned name
+ expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)})
+ assert_identical(result, expected)
+ np.testing.assert_equal(da.to_numpy(), arr)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index ac7cfa4cbb9..634c57e4124 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -44,9 +44,11 @@
has_dask,
requires_bottleneck,
requires_cftime,
+ requires_cupy,
requires_dask,
requires_numbagg,
requires_numexpr,
+ requires_pint_0_15,
requires_scipy,
requires_sparse,
source_ndarray,
@@ -728,7 +730,7 @@ def test_coords_modify(self):
def test_update_index(self):
actual = Dataset(coords={"x": [1, 2, 3]})
actual["x"] = ["a", "b", "c"]
- assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"]))
+ assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"]))
def test_coords_setitem_with_new_dimension(self):
actual = Dataset()
@@ -3219,13 +3221,13 @@ def test_update(self):
data = create_test_data(seed=0)
expected = data.copy()
var2 = Variable("dim1", np.arange(8))
- actual = data.update({"var2": var2})
+ actual = data
+ actual.update({"var2": var2})
expected["var2"] = var2
assert_identical(expected, actual)
actual = data.copy()
- actual_result = actual.update(data)
- assert actual_result is actual
+ actual.update(data)
assert_identical(expected, actual)
other = Dataset(attrs={"new": "attr"})
@@ -3585,6 +3587,7 @@ def test_setitem_align_new_indexes(self):
def test_setitem_str_dtype(self, dtype):
ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)})
+ # test Dataset update
ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"])
assert np.issubdtype(ds.x.dtype, dtype)
@@ -4997,6 +5000,12 @@ def test_rank(self):
with pytest.raises(ValueError, match=r"does not contain"):
x.rank("invalid_dim")
+ def test_rank_use_bottleneck(self):
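+        # with bottleneck disabled, rank() is expected to raise rather than silently fall back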
+ ds = Dataset({"a": ("x", [0, np.nan, 2]), "b": ("y", [4, 6, 3, 4])})
+ with xr.set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ ds.rank("x")
+
def test_count(self):
ds = Dataset({"x": ("a", [np.nan, 1]), "y": 0, "z": np.nan})
expected = Dataset({"x": 1, "y": 1, "z": 0})
@@ -5223,10 +5232,19 @@ def test_dataset_transpose(self):
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims
- with pytest.raises(ValueError, match=r"permuted"):
- ds.transpose("dim1", "dim2", "dim3")
- with pytest.raises(ValueError, match=r"permuted"):
- ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim")
+ # test missing dimension, raise error
+ with pytest.raises(ValueError):
+ ds.transpose(..., "not_a_dim")
+
+ # test missing dimension, ignore error
+ actual = ds.transpose(..., "not_a_dim", missing_dims="ignore")
+ expected_ell = ds.transpose(...)
+ assert_identical(expected_ell, actual)
+
+ # test missing dimension, raise warning
+ with pytest.warns(UserWarning):
+ actual = ds.transpose(..., "not_a_dim", missing_dims="warn")
+ assert_identical(expected_ell, actual)
assert "T" not in dir(ds)
@@ -6128,41 +6146,6 @@ def test_rolling_keep_attrs(funcname, argument):
assert result.da_not_rolled.name == "da_not_rolled"
-def test_rolling_keep_attrs_deprecated():
- global_attrs = {"units": "test", "long_name": "testing"}
- attrs_da = {"da_attr": "test"}
-
- data = np.linspace(10, 15, 100)
- coords = np.linspace(1, 10, 100)
-
- ds = Dataset(
- data_vars={"da": ("coord", data)},
- coords={"coord": coords},
- attrs=global_attrs,
- )
- ds.da.attrs = attrs_da
-
- # deprecated option
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated"
- ):
- result = ds.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim")
-
- assert result.attrs == {}
- assert result.da.attrs == {}
-
- # the keep_attrs in the reduction function takes precedence
- with pytest.warns(
- FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated"
- ):
- result = ds.rolling(dim={"coord": 5}, keep_attrs=True).construct(
- "window_dim", keep_attrs=False
- )
-
- assert result.attrs == {}
- assert result.da.attrs == {}
-
-
def test_rolling_properties(ds):
# catching invalid args
with pytest.raises(ValueError, match="window must be > 0"):
@@ -6608,9 +6591,6 @@ def test_integrate(dask):
with pytest.raises(ValueError):
da.integrate("x2d")
- with pytest.warns(FutureWarning):
- da.integrate(dim="x")
-
@requires_scipy
@pytest.mark.parametrize("dask", [True, False])
@@ -6779,3 +6759,74 @@ def test_clip(ds):
result = ds.clip(min=ds.mean("y"), max=ds.mean("y"))
assert result.dims == ds.dims
+
+
+class TestNumpyCoercion:
+ def test_from_numpy(self):
+ ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])})
+
+ assert_identical(ds.as_numpy(), ds)
+
+ @requires_dask
+ def test_from_dask(self):
+ ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])})
+ ds_chunked = ds.chunk(1)
+
+ assert_identical(ds_chunked.as_numpy(), ds.compute())
+
+ @requires_pint_0_15
+ def test_from_pint(self):
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+ ds = xr.Dataset(
+ {"a": ("x", Quantity(arr, units="Pa"))},
+ coords={"lat": ("x", Quantity(arr + 3, units="m"))},
+ )
+
+ expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)})
+ assert_identical(ds.as_numpy(), expected)
+
+ @requires_sparse
+ def test_from_sparse(self):
+ import sparse
+
+ arr = np.diagflat([1, 2, 3])
+ sparr = sparse.COO.from_numpy(arr)
+ ds = xr.Dataset(
+ {"a": (["x", "y"], sparr)}, coords={"elev": (("x", "y"), sparr + 3)}
+ )
+
+ expected = xr.Dataset(
+ {"a": (["x", "y"], arr)}, coords={"elev": (("x", "y"), arr + 3)}
+ )
+ assert_identical(ds.as_numpy(), expected)
+
+ @requires_cupy
+ def test_from_cupy(self):
+ import cupy as cp
+
+ arr = np.array([1, 2, 3])
+ ds = xr.Dataset(
+ {"a": ("x", cp.array(arr))}, coords={"lat": ("x", cp.array(arr + 3))}
+ )
+
+ expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)})
+ assert_identical(ds.as_numpy(), expected)
+
+ @requires_dask
+ @requires_pint_0_15
+ def test_from_pint_wrapping_dask(self):
+ import dask
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+ d = dask.array.from_array(arr)
+ ds = xr.Dataset(
+ {"a": ("x", Quantity(d, units="Pa"))},
+ coords={"lat": ("x", Quantity(d, units="m") * 2)},
+ )
+
+ result = ds.as_numpy()
+ expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)})
+ assert_identical(result, expected)
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 433e2e58de2..ab0d1d9f22c 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -184,11 +184,7 @@ def test_dask_distributed_cfgrib_integration_test(loop):
assert_allclose(actual, expected)
-@pytest.mark.skipif(
- distributed.__version__ <= "1.19.3",
- reason="Need recent distributed version to clean up get",
-)
-@gen_cluster(client=True, timeout=None)
+@gen_cluster(client=True)
async def test_async(c, s, a, b):
x = create_test_data()
assert not dask.is_dask_collection(x)
diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py
index 47640ef2d95..09c6fa0cf3c 100644
--- a/xarray/tests/test_formatting_html.py
+++ b/xarray/tests/test_formatting_html.py
@@ -1,5 +1,3 @@
-from distutils.version import LooseVersion
-
import numpy as np
import pandas as pd
import pytest
@@ -57,19 +55,9 @@ def test_short_data_repr_html_non_str_keys(dataset):
def test_short_data_repr_html_dask(dask_dataarray):
- import dask
-
- if LooseVersion(dask.__version__) < "2.0.0":
- assert not hasattr(dask_dataarray.data, "_repr_html_")
- data_repr = fh.short_data_repr_html(dask_dataarray)
- assert (
- data_repr
- == "dask.array<xarray-<this-array>, shape=(4, 6), dtype=float64, chunksize=(4, 6)>"
- )
- else:
- assert hasattr(dask_dataarray.data, "_repr_html_")
- data_repr = fh.short_data_repr_html(dask_dataarray)
- assert data_repr == dask_dataarray.data._repr_html_()
+ assert hasattr(dask_dataarray.data, "_repr_html_")
+ data_repr = fh.short_data_repr_html(dask_dataarray)
+ assert data_repr == dask_dataarray.data._repr_html_()
def test_format_dims_no_dims():
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index defc6212228..c8ba72a253f 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -2,7 +2,9 @@
import pandas as pd
import pytest
+import xarray as xr
from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe
+from xarray.core.variable import IndexVariable
def test_asarray_tuplesafe():
@@ -18,9 +20,57 @@ def test_asarray_tuplesafe():
class TestPandasIndex:
+ def test_constructor(self):
+ pd_idx = pd.Index([1, 2, 3])
+ index = PandasIndex(pd_idx, "x")
+
+ assert index.index is pd_idx
+ assert index.dim == "x"
+
+ def test_from_variables(self):
+ var = xr.Variable(
+ "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32}
+ )
+
+ index, index_vars = PandasIndex.from_variables({"x": var})
+ xr.testing.assert_identical(var.to_index_variable(), index_vars["x"])
+ assert index.dim == "x"
+ assert index.index.equals(index_vars["x"].to_index())
+
+ var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]])
+ with pytest.raises(ValueError, match=r".*only accepts one variable.*"):
+ PandasIndex.from_variables({"x": var, "foo": var2})
+
+ with pytest.raises(
+ ValueError, match=r".*only accepts a 1-dimensional variable.*"
+ ):
+ PandasIndex.from_variables({"foo": var2})
+
+ def test_from_pandas_index(self):
+ pd_idx = pd.Index([1, 2, 3], name="foo")
+
+ index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x")
+
+ assert index.dim == "x"
+ assert index.index is pd_idx
+ assert index.index.name == "foo"
+ xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3]))
+
+ # test no name set for pd.Index
+ pd_idx.name = None
+ index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x")
+ assert "x" in index_vars
+ assert index.index is not pd_idx
+ assert index.index.name == "x"
+
+    def test_to_pandas_index(self):
+ pd_idx = pd.Index([1, 2, 3], name="foo")
+ index = PandasIndex(pd_idx, "x")
+ assert index.to_pandas_index() is pd_idx
+
def test_query(self):
# TODO: add tests that aren't just for edge cases
- index = PandasIndex(pd.Index([1, 2, 3]))
+ index = PandasIndex(pd.Index([1, 2, 3]), "x")
with pytest.raises(KeyError, match=r"not all values found"):
index.query({"x": [0]})
with pytest.raises(KeyError):
@@ -29,7 +79,9 @@ def test_query(self):
index.query({"x": {"one": 0}})
def test_query_datetime(self):
- index = PandasIndex(pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]))
+ index = PandasIndex(
+ pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x"
+ )
actual = index.query({"x": "2001-01-01"})
expected = (1, None)
assert actual == expected
@@ -38,18 +90,96 @@ def test_query_datetime(self):
assert actual == expected
def test_query_unsorted_datetime_index_raises(self):
- index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]))
+ index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x")
with pytest.raises(KeyError):
# pandas will try to convert this into an array indexer. We should
# raise instead, so we can be sure the result of indexing with a
# slice is always a view.
index.query({"x": slice("2001", "2002")})
+ def test_equals(self):
+ index1 = PandasIndex([1, 2, 3], "x")
+ index2 = PandasIndex([1, 2, 3], "x")
+ assert index1.equals(index2) is True
+
+ def test_union(self):
+ index1 = PandasIndex([1, 2, 3], "x")
+ index2 = PandasIndex([4, 5, 6], "y")
+ actual = index1.union(index2)
+ assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6]))
+ assert actual.dim == "x"
+
+ def test_intersection(self):
+ index1 = PandasIndex([1, 2, 3], "x")
+ index2 = PandasIndex([2, 3, 4], "y")
+ actual = index1.intersection(index2)
+ assert actual.index.equals(pd.Index([2, 3]))
+ assert actual.dim == "x"
+
+ def test_copy(self):
+ expected = PandasIndex([1, 2, 3], "x")
+ actual = expected.copy()
+
+ assert actual.index.equals(expected.index)
+ assert actual.index is not expected.index
+ assert actual.dim == expected.dim
+
+ def test_getitem(self):
+ pd_idx = pd.Index([1, 2, 3])
+ expected = PandasIndex(pd_idx, "x")
+ actual = expected[1:]
+
+ assert actual.index.equals(pd_idx[1:])
+ assert actual.dim == expected.dim
+
class TestPandasMultiIndex:
+ def test_from_variables(self):
+ v_level1 = xr.Variable(
+ "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32}
+ )
+ v_level2 = xr.Variable(
+ "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"}
+ )
+
+ index, index_vars = PandasMultiIndex.from_variables(
+ {"level1": v_level1, "level2": v_level2}
+ )
+
+ expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data])
+ assert index.dim == "x"
+ assert index.index.equals(expected_idx)
+
+ assert list(index_vars) == ["x", "level1", "level2"]
+ xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"])
+ xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"])
+ xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"])
+
+ var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]])
+ with pytest.raises(
+ ValueError, match=r".*only accepts 1-dimensional variables.*"
+ ):
+ PandasMultiIndex.from_variables({"var": var})
+
+ v_level3 = xr.Variable("y", [4, 5, 6])
+ with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"):
+ PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3})
+
+ def test_from_pandas_index(self):
+ pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar"))
+
+ index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x")
+
+ assert index.dim == "x"
+ assert index.index is pd_idx
+ assert index.index.names == ("foo", "bar")
+ xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx))
+ xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3]))
+ xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", [4, 5, 6]))
+
def test_query(self):
index = PandasMultiIndex(
- pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
+ pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x"
)
# test tuples inside slice are considered as scalar indexer values
assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None)
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 1909d309cf5..6e4fd320029 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -81,9 +81,12 @@ def test_group_indexers_by_index(self):
def test_remap_label_indexers(self):
def test_indexer(data, x, expected_pos, expected_idx=None):
- pos, idx = indexing.remap_label_indexers(data, {"x": x})
+ pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x})
+ idx, _ = new_idx_vars.get("x", (None, None))
+ if idx is not None:
+ idx = idx.to_pandas_index()
assert_array_equal(pos.get("x"), expected_pos)
- assert_array_equal(idx.get("x"), expected_idx)
+ assert_array_equal(idx, expected_idx)
data = Dataset({"x": ("x", [1, 2, 3])})
mindex = pd.MultiIndex.from_product(
diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py
index 4f6dc616504..2029e6af05b 100644
--- a/xarray/tests/test_interp.py
+++ b/xarray/tests/test_interp.py
@@ -727,6 +727,7 @@ def test_datetime_interp_noerror():
@requires_cftime
+@requires_scipy
def test_3641():
times = xr.cftime_range("0001", periods=3, freq="500Y")
da = xr.DataArray(range(3), dims=["time"], coords=[times])
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index e2dfac04222..1ebcd9ac6f7 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -392,6 +392,38 @@ def test_ffill():
assert_equal(actual, expected)
+def test_ffill_use_bottleneck():
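+    # with bottleneck disabled, ffill is expected to raise rather than silently fall back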
+ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
+ with xr.set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ da.ffill("x")
+
+
+@requires_dask
+def test_ffill_use_bottleneck_dask():
+ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
+ da = da.chunk({"x": 1})
+ with xr.set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ da.ffill("x")
+
+
+def test_bfill_use_bottleneck():
+ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
+ with xr.set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ da.bfill("x")
+
+
+@requires_dask
+def test_bfill_use_bottleneck_dask():
+ da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
+ da = da.chunk({"x": 1})
+ with xr.set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ da.bfill("x")
+
+
@requires_bottleneck
@requires_dask
@pytest.mark.parametrize("method", ["ffill", "bfill"])
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
index a5ffb97db38..ee8bafb8fa7 100644
--- a/xarray/tests/test_plot.py
+++ b/xarray/tests/test_plot.py
@@ -2912,3 +2912,41 @@ def test_maybe_gca():
assert existing_axes == ax
# kwargs are ignored when reusing axes
assert ax.get_aspect() == "auto"
+
+
+@requires_matplotlib
+@pytest.mark.parametrize(
+ "x, y, z, hue, markersize, row, col, add_legend, add_colorbar",
+ [
+ ("A", "B", None, None, None, None, None, None, None),
+ ("B", "A", None, "w", None, None, None, True, None),
+ ("A", "B", None, "y", "x", None, None, True, True),
+ ("A", "B", "z", None, None, None, None, None, None),
+ ("B", "A", "z", "w", None, None, None, True, None),
+ ("A", "B", "z", "y", "x", None, None, True, True),
+ ("A", "B", "z", "y", "x", "w", None, True, True),
+ ],
+)
+def test_dataarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar):
+    """Test DataArray scatter. Merge with TestPlot1D eventually."""
+ ds = xr.tutorial.scatter_example_dataset()
+
+ extra_coords = [v for v in [x, hue, markersize] if v is not None]
+
+ # Base coords:
+ coords = dict(ds.coords)
+
+ # Add extra coords to the DataArray:
+ coords.update({v: ds[v] for v in extra_coords})
+
+ darray = xr.DataArray(ds[y], coords=coords)
+
+ with figure_context():
+ darray.plot._scatter(
+ x=x,
+ z=z,
+ hue=hue,
+ markersize=markersize,
+ add_legend=add_legend,
+ add_colorbar=add_colorbar,
+ )
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index 17086049cc7..2140047f38e 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -5,10 +5,22 @@
import pandas as pd
import pytest
+try:
+ import matplotlib.pyplot as plt
+except ImportError:
+ pass
+
import xarray as xr
from xarray.core import dtypes, duck_array_ops
-from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
+from . import (
+ assert_allclose,
+ assert_duckarray_allclose,
+ assert_equal,
+ assert_identical,
+ requires_matplotlib,
+)
+from .test_plot import PlotTestCase
from .test_variable import _PAD_XR_NP_ARGS
pint = pytest.importorskip("pint")
@@ -5564,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype):
assert_units_equal(expected, actual)
assert_equal(expected, actual)
+
+
+@requires_matplotlib
+class TestPlots(PlotTestCase):
+ def test_units_in_line_plot_labels(self):
+ arr = np.linspace(1, 10, 3) * unit_registry.Pa
+ # TODO make coord a Quantity once unit-aware indexes supported
+ x_coord = xr.DataArray(
+ np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
+ )
+ da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")
+
+ da.plot.line()
+
+ ax = plt.gca()
+ assert ax.get_ylabel() == "pressure [pascal]"
+ assert ax.get_xlabel() == "x [meters]"
+
+ def test_units_in_2d_plot_labels(self):
+ arr = np.ones((2, 3)) * unit_registry.Pa
+ da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")
+
+ fig, (ax, cax) = plt.subplots(1, 2)
+ ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)
+
+ assert cax.get_ylabel() == "pressure [pascal]"
diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py
index 9c78caea4d6..ce796e9de49 100644
--- a/xarray/tests/test_utils.py
+++ b/xarray/tests/test_utils.py
@@ -7,7 +7,6 @@
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core import duck_array_ops, utils
-from xarray.core.indexes import PandasIndex
from xarray.core.utils import either_dict_or_kwargs, iterate_nested
from . import assert_array_equal, requires_cftime, requires_dask
@@ -29,13 +28,11 @@ def test_safe_cast_to_index():
dates = pd.date_range("2000-01-01", periods=10)
x = np.arange(5)
td = x * np.timedelta64(1, "D")
- midx = pd.MultiIndex.from_tuples([(0,)], names=["a"])
for expected, array in [
(dates, dates.values),
(pd.Index(x, dtype=object), x.astype(object)),
(pd.Index(td), td),
(pd.Index(td, dtype=object), td.astype(object)),
- (midx, PandasIndex(midx)),
]:
actual = utils.safe_cast_to_index(array)
assert_array_equal(expected, actual)
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 1e0dff45dd2..7f3ba9123d9 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -11,7 +11,6 @@
from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options
from xarray.core import dtypes, duck_array_ops, indexing
from xarray.core.common import full_like, ones_like, zeros_like
-from xarray.core.indexes import PandasIndex
from xarray.core.indexing import (
BasicIndexer,
CopyOnWriteArray,
@@ -20,6 +19,7 @@
MemoryCachedArray,
NumpyIndexingAdapter,
OuterIndexer,
+ PandasIndexingAdapter,
VectorizedIndexer,
)
from xarray.core.pycompat import dask_array_type
@@ -33,7 +33,9 @@
assert_equal,
assert_identical,
raise_if_dask_computes,
+ requires_cupy,
requires_dask,
+ requires_pint_0_15,
requires_sparse,
source_ndarray,
)
@@ -535,7 +537,7 @@ def test_copy_index(self):
v = self.cls("x", midx)
for deep in [True, False]:
w = v.copy(deep=deep)
- assert isinstance(w._data, PandasIndex)
+ assert isinstance(w._data, PandasIndexingAdapter)
assert isinstance(w.to_index(), pd.MultiIndex)
assert_array_equal(v._data.array, w._data.array)
@@ -1160,7 +1162,7 @@ def test_as_variable(self):
td = np.array([timedelta(days=x) for x in range(10)])
assert as_variable(td, "time").dtype.kind == "m"
- with pytest.warns(DeprecationWarning):
+ with pytest.raises(TypeError):
as_variable(("x", DataArray([])))
def test_repr(self):
@@ -1466,6 +1468,20 @@ def test_transpose(self):
w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x))
assert_identical(w, w3.transpose("a", "b", "c", "d"))
+ # test missing dimension, raise error
+ with pytest.raises(ValueError):
+ v.transpose(..., "not_a_dim")
+
+ # test missing dimension, ignore error
+ actual = v.transpose(..., "not_a_dim", missing_dims="ignore")
+ expected_ell = v.transpose(...)
+ assert_identical(expected_ell, actual)
+
+ # test missing dimension, raise warning
+ with pytest.warns(UserWarning):
+ v.transpose(..., "not_a_dim", missing_dims="warn")
+ assert_identical(expected_ell, actual)
+
def test_transpose_0d(self):
for value in [
3.5,
@@ -1657,6 +1673,23 @@ def test_reduce(self):
with pytest.raises(ValueError, match=r"cannot supply both"):
v.mean(dim="x", axis=0)
+ @requires_bottleneck
+ def test_reduce_use_bottleneck(self, monkeypatch):
+ def raise_if_called(*args, **kwargs):
+ raise RuntimeError("should not have been called")
+
+ import bottleneck as bn
+
+ monkeypatch.setattr(bn, "nanmin", raise_if_called)
+
+ v = Variable("x", [0.0, np.nan, 1.0])
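+        # with bottleneck enabled, min() should hit the patched nanmin and raise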
+ with pytest.raises(RuntimeError, match="should not have been called"):
+ with set_options(use_bottleneck=True):
+ v.min()
+
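+        # with bottleneck disabled, min() should skip bottleneck entirely and succeed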
+ with set_options(use_bottleneck=False):
+ v.min()
+
@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
@@ -1704,6 +1737,12 @@ def test_rank_dask_raises(self):
with pytest.raises(TypeError, match=r"arrays stored as dask"):
v.rank("x")
+ def test_rank_use_bottleneck(self):
+ v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0])
+ with set_options(use_bottleneck=False):
+ with pytest.raises(RuntimeError):
+ v.rank("x")
+
@requires_bottleneck
def test_rank(self):
import bottleneck as bn
@@ -2145,7 +2184,7 @@ def test_multiindex_default_level_names(self):
def test_data(self):
x = IndexVariable("x", np.arange(3.0))
- assert isinstance(x._data, PandasIndex)
+ assert isinstance(x._data, PandasIndexingAdapter)
assert isinstance(x.data, np.ndarray)
assert float == x.dtype
assert_array_equal(np.arange(3), x)
@@ -2287,7 +2326,7 @@ def test_coarsen_2d(self):
class TestAsCompatibleData:
def test_unchanged_types(self):
- types = (np.asarray, PandasIndex, LazilyIndexedArray)
+ types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray)
for t in types:
for data in [
np.arange(3),
@@ -2540,3 +2579,68 @@ def test_clip(var):
var.mean("z").data[:, :, np.newaxis],
),
)
+
+
+@pytest.mark.parametrize("Var", [Variable, IndexVariable])
+class TestNumpyCoercion:
+ def test_from_numpy(self, Var):
+ v = Var("x", [1, 2, 3])
+
+ assert_identical(v.as_numpy(), v)
+ np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3]))
+
+ @requires_dask
+ def test_from_dask(self, Var):
+ v = Var("x", [1, 2, 3])
+ v_chunked = v.chunk(1)
+
+ assert_identical(v_chunked.as_numpy(), v.compute())
+ np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3]))
+
+ @requires_pint_0_15
+ def test_from_pint(self, Var):
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+ v = Var("x", Quantity(arr, units="m"))
+
+ assert_identical(v.as_numpy(), Var("x", arr))
+ np.testing.assert_equal(v.to_numpy(), arr)
+
+ @requires_sparse
+ def test_from_sparse(self, Var):
+ if Var is IndexVariable:
+ pytest.skip("Can't have 2D IndexVariables")
+
+ import sparse
+
+ arr = np.diagflat([1, 2, 3])
+ sparr = sparse.COO(coords=[[0, 1, 2], [0, 1, 2]], data=[1, 2, 3])
+ v = Variable(["x", "y"], sparr)
+
+ assert_identical(v.as_numpy(), Variable(["x", "y"], arr))
+ np.testing.assert_equal(v.to_numpy(), arr)
+
+ @requires_cupy
+ def test_from_cupy(self, Var):
+ import cupy as cp
+
+ arr = np.array([1, 2, 3])
+ v = Var("x", cp.array(arr))
+
+ assert_identical(v.as_numpy(), Var("x", arr))
+ np.testing.assert_equal(v.to_numpy(), arr)
+
+ @requires_dask
+ @requires_pint_0_15
+ def test_from_pint_wrapping_dask(self, Var):
+ import dask
+ from pint import Quantity
+
+ arr = np.array([1, 2, 3])
+        d = dask.array.from_array(arr)
+ v = Var("x", Quantity(d, units="m"))
+
+ result = v.as_numpy()
+ assert_identical(result, Var("x", arr))
+ np.testing.assert_equal(v.to_numpy(), arr)