Skip to content

Commit

Permalink
Compability updates with xarray 2023.12 (#6026)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro authored Dec 11, 2023
1 parent 5594e74 commit 2913005
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
12 changes: 9 additions & 3 deletions holoviews/core/data/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,18 @@ def retrieve_unit_and_label(dim):
arrays[vdim.name] = arr
data = xr.Dataset(arrays)
else:
# Started to warn in xarray 2023.12.0:
# The return type of `Dataset.dims` will be changed to return a
# set of dimension names in future, in order to be more consistent
# with `DataArray.dims`. To access a mapping from dimension names to
# lengths, please use `Dataset.sizes`.
data_info = data.sizes if hasattr(data, "sizes") else data.dims
if not data.coords:
data = data.assign_coords(**{k: range(v) for k, v in data.dims.items()})
data = data.assign_coords(**{k: range(v) for k, v in data_info.items()})
if vdims is None:
vdims = list(data.data_vars)
if kdims is None:
xrdims = list(data.dims)
xrdims = list(data_info)
xrcoords = list(data.coords)
kdims = [name for name in data.indexes.keys()
if isinstance(data[name].data, np.ndarray)]
Expand Down Expand Up @@ -636,7 +642,7 @@ def length(cls, dataset):
def dframe(cls, dataset, dimensions):
import xarray as xr
if cls.packed(dataset):
bands = {vd.name: dataset.data[..., i].drop('band')
bands = {vd.name: dataset.data[..., i].drop_vars('band')
for i, vd in enumerate(dataset.vdims)}
data = xr.Dataset(bands)
else:
Expand Down
12 changes: 10 additions & 2 deletions holoviews/tests/core/data/test_xarrayinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,16 @@ def test_select_dropped_dimensions_restoration(self):
coords=dict(chain=range(d.shape[0]), value=range(d.shape[1])))
ds = Dataset(da)
t = ds.select(chain=0)
self.assertEqual(t.data.dims , dict(chain=1,value=8))
self.assertEqual(t.data.stuff.shape , (1,8))
if hasattr(t.data, "sizes"):
# Started to warn in xarray 2023.12.0:
# The return type of `Dataset.dims` will be changed to return a
# set of dimension names in future, in order to be more consistent
# with `DataArray.dims`. To access a mapping from dimension names to
# lengths, please use `Dataset.sizes`.
assert t.data.sizes == dict(chain=1, value=8)
else:
assert t.data.dims == dict(chain=1, value=8)
assert t.data.stuff.shape == (1, 8)

def test_mask_2d_array_transposed(self):
array = np.random.rand(4, 3)
Expand Down

0 comments on commit 2913005

Please sign in to comment.