Skip to content

Commit

Permalink
to_stacked_array: better error msg & refactor (#8130)
Browse files (browse the repository at this point in the history)
* to_stacked_array: better error msg & refactor

* fix regex

* Update xarray/core/dataset.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Apply suggestions from code review

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
mathause and dcherian authored Sep 10, 2023
1 parent 336aec0 commit 0b3b20a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
38 changes: 16 additions & 22 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@
)
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.parallelcompat import (
get_chunked_array_type,
guess_chunkmanager,
)
from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.core.pycompat import (
array_type,
is_chunked_array,
Expand Down Expand Up @@ -5275,34 +5272,31 @@ def to_stacked_array(

stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims)

for variable in self:
dims = self[variable].dims
dims_include_sample_dims = set(sample_dims) <= set(dims)
if not dims_include_sample_dims:
for key, da in self.data_vars.items():
missing_sample_dims = set(sample_dims) - set(da.dims)
if missing_sample_dims:
raise ValueError(
"All variables in the dataset must contain the "
f"dimensions {dims}."
"Variables in the dataset must contain all ``sample_dims`` "
f"({sample_dims!r}) but '{key}' misses {sorted(missing_sample_dims)}"
)

def ensure_stackable(val):
assign_coords = {variable_dim: val.name}
for dim in stacking_dims:
if dim not in val.dims:
assign_coords[dim] = None
def stack_dataarray(da):
# add missing dims/ coords and the name of the variable

missing_stack_coords = {variable_dim: da.name}
for dim in set(stacking_dims) - set(da.dims):
missing_stack_coords[dim] = None

expand_dims = set(stacking_dims).difference(set(val.dims))
expand_dims.add(variable_dim)
# must be list for .expand_dims
expand_dims = list(expand_dims)
missing_stack_dims = list(missing_stack_coords)

return (
val.assign_coords(**assign_coords)
.expand_dims(expand_dims)
da.assign_coords(**missing_stack_coords)
.expand_dims(missing_stack_dims)
.stack({new_dim: (variable_dim,) + stacking_dims})
)

# concatenate the arrays
stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars]
stackable_vars = [stack_dataarray(da) for da in self.data_vars.values()]
data_array = concat(stackable_vars, dim=new_dim)

if name is not None:
Expand Down
5 changes: 4 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3792,7 +3792,10 @@ def test_to_stacked_array_invalid_sample_dims(self) -> None:
data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])},
coords={"y": ["u", "v", "w"]},
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=r"Variables in the dataset must contain all ``sample_dims`` \(\['y'\]\) but 'b' misses \['y'\]",
):
data.to_stacked_array("features", sample_dims=["y"])

def test_to_stacked_array_name(self) -> None:
Expand Down

0 comments on commit 0b3b20a

Please sign in to comment.