diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index c3e1631c..87450586 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -667,6 +667,18 @@ def test_linear_expression_groupby(v, use_fallback): assert grouped.nterm == 10 +@pytest.mark.parametrize("use_fallback", [True]) +def test_linear_expression_groupby_ndim(z, use_fallback): + # TODO: implement fallback for n-dim groupby, see https://github.com/PyPSA/linopy/issues/299 + expr = 1 * z + groups = xr.DataArray([[1, 1, 2], [1, 3, 3]], coords=z.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # there are three groups, 1, 2 and 3, the largest group has 3 elements + assert (grouped.data.group == [1, 2, 3]).all() + assert grouped.nterm == 3 + + @pytest.mark.parametrize("use_fallback", [True, False]) def test_linear_expression_groupby_with_name(v, use_fallback): expr = 1 * v