Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dataset groupby tests #5506

Merged
merged 2 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from unittest import mock # noqa: F401

import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal # noqa: F401
from pandas.testing import assert_frame_equal # noqa: F401

import xarray.testing
from xarray import Dataset
from xarray.core import utils
from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401
from xarray.core.indexing import ExplicitlyIndexed
Expand Down Expand Up @@ -200,3 +202,30 @@ def assert_allclose(a, b, **kwargs):
xarray.testing.assert_allclose(a, b, **kwargs)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def create_test_data(seed=None, add_attrs=True):
rs = np.random.RandomState(seed)
_vars = {
"var1": ["dim1", "dim2"],
"var2": ["dim1", "dim2"],
"var3": ["dim3", "dim1"],
}
_dims = {"dim1": 8, "dim2": 9, "dim3": 10}

obj = Dataset()
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
obj["dim3"] = ("dim3", list("abcdefghij"))
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
for v, dims in sorted(_vars.items()):
data = rs.normal(size=tuple(_dims[d] for d in dims))
obj[v] = (dims, data)
if add_attrs:
obj[v].attrs = {"foo": "variable"}
obj.coords["numbers"] = (
"dim3",
np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
)
obj.encoding = {"foo": "bar"}
assert all(obj.data.flags.writeable for obj in obj.variables.values())
return obj
195 changes: 1 addition & 194 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
assert_array_equal,
assert_equal,
assert_identical,
create_test_data,
has_cftime,
has_dask,
requires_bottleneck,
Expand All @@ -62,33 +63,6 @@
]


def create_test_data(seed=None, add_attrs=True):
rs = np.random.RandomState(seed)
_vars = {
"var1": ["dim1", "dim2"],
"var2": ["dim1", "dim2"],
"var3": ["dim3", "dim1"],
}
_dims = {"dim1": 8, "dim2": 9, "dim3": 10}

obj = Dataset()
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
obj["dim3"] = ("dim3", list("abcdefghij"))
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
for v, dims in sorted(_vars.items()):
data = rs.normal(size=tuple(_dims[d] for d in dims))
obj[v] = (dims, data)
if add_attrs:
obj[v].attrs = {"foo": "variable"}
obj.coords["numbers"] = (
"dim3",
np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
)
obj.encoding = {"foo": "bar"}
assert all(obj.data.flags.writeable for obj in obj.variables.values())
return obj


def create_append_test_data(seed=None):
rs = np.random.RandomState(seed)

Expand Down Expand Up @@ -3785,173 +3759,6 @@ def test_squeeze_drop(self):
selected = data.squeeze(drop=True)
assert_identical(data, selected)

def test_groupby(self):
data = Dataset(
{"z": (["x", "y"], np.random.randn(3, 5))},
{"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)},
)
groupby = data.groupby("x")
assert len(groupby) == 3
expected_groups = {"a": 0, "b": 1, "c": 2}
assert groupby.groups == expected_groups
expected_items = [
("a", data.isel(x=0)),
("b", data.isel(x=1)),
("c", data.isel(x=2)),
]
for actual, expected in zip(groupby, expected_items):
assert actual[0] == expected[0]
assert_equal(actual[1], expected[1])

def identity(x):
return x

for k in ["x", "c", "y"]:
actual = data.groupby(k, squeeze=False).map(identity)
assert_equal(data, actual)

def test_groupby_returns_new_type(self):
data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))})

actual = data.groupby("x").map(lambda ds: ds["z"])
expected = data["z"]
assert_identical(expected, actual)

actual = data["z"].groupby("x").map(lambda x: x.to_dataset())
expected = data
assert_identical(expected, actual)

def test_groupby_iter(self):
data = create_test_data()
for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]):
assert data["dim1"][n] == t
assert_equal(data["var1"][n], sub["var1"])
assert_equal(data["var2"][n], sub["var2"])
assert_equal(data["var3"][:, n], sub["var3"])

def test_groupby_errors(self):
data = create_test_data()
with pytest.raises(TypeError, match=r"`group` must be"):
data.groupby(np.arange(10))
with pytest.raises(ValueError, match=r"length does not match"):
data.groupby(data["dim1"][:3])
with pytest.raises(TypeError, match=r"`group` must be"):
data.groupby(data.coords["dim1"].to_index())

def test_groupby_reduce(self):
data = Dataset(
{
"xy": (["x", "y"], np.random.randn(3, 4)),
"xonly": ("x", np.random.randn(3)),
"yonly": ("y", np.random.randn(4)),
"letters": ("y", ["a", "a", "b", "b"]),
}
)

expected = data.mean("y")
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
actual = data.groupby("x").mean(...)
assert_allclose(expected, actual)

actual = data.groupby("x").mean("y")
assert_allclose(expected, actual)

letters = data["letters"]
expected = Dataset(
{
"xy": data["xy"].groupby(letters).mean(...),
"xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})),
"yonly": data["yonly"].groupby(letters).mean(),
}
)
actual = data.groupby("letters").mean(...)
assert_allclose(expected, actual)

def test_groupby_math(self):
def reorder_dims(x):
return x.transpose("dim1", "dim2", "dim3", "time")

ds = create_test_data()
ds["dim1"] = ds["dim1"]
for squeeze in [True, False]:
grouped = ds.groupby("dim1", squeeze=squeeze)

expected = reorder_dims(ds + ds.coords["dim1"])
actual = grouped + ds.coords["dim1"]
assert_identical(expected, reorder_dims(actual))

actual = ds.coords["dim1"] + grouped
assert_identical(expected, reorder_dims(actual))

ds2 = 2 * ds
expected = reorder_dims(ds + ds2)
actual = grouped + ds2
assert_identical(expected, reorder_dims(actual))

actual = ds2 + grouped
assert_identical(expected, reorder_dims(actual))

grouped = ds.groupby("numbers")
zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))])
expected = (ds + Variable("dim3", np.zeros(10))).transpose(
"dim3", "dim1", "dim2", "time"
)
actual = grouped + zeros
assert_equal(expected, actual)

actual = zeros + grouped
assert_equal(expected, actual)

with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
grouped + ds
with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
ds + grouped
with pytest.raises(TypeError, match=r"only support binary ops"):
grouped + 1
with pytest.raises(TypeError, match=r"only support binary ops"):
grouped + grouped
with pytest.raises(TypeError, match=r"in-place operations"):
ds += grouped

ds = Dataset(
{
"x": ("time", np.arange(100)),
"time": pd.date_range("2000-01-01", periods=100),
}
)
with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
ds + ds.groupby("time.month")

def test_groupby_math_virtual(self):
ds = Dataset(
{"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)}
)
grouped = ds.groupby("t.day")
actual = grouped - grouped.mean(...)
expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]])
assert_identical(actual, expected)

def test_groupby_nan(self):
# nan should be excluded from groupby
ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])})
actual = ds.groupby("bar").mean(...)
expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]})
assert_identical(actual, expected)

def test_groupby_order(self):
# groupby should preserve variables order
ds = Dataset()
for vn in ["a", "b", "c"]:
ds[vn] = DataArray(np.arange(10), dims=["t"])
data_vars_ref = list(ds.data_vars.keys())
ds = ds.groupby("t").mean(...)
data_vars = list(ds.data_vars.keys())
assert data_vars == data_vars_ref
# coords are now at the end of the list, so the test below fails
# all_vars = list(ds.variables.keys())
# all_vars_ref = list(ds.variables.keys())
# self.assertEqual(all_vars, all_vars_ref)

def test_resample_and_first(self):
times = pd.date_range("2000-01-01", freq="6H", periods=10)
ds = Dataset(
Expand Down
Loading