Skip to content

Commit

Permalink
Fix create_grid breaking with tuple containing None
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Aug 31, 2023
1 parent 08cfea3 commit fd530aa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
12 changes: 10 additions & 2 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,13 +885,21 @@ def test_empty_grid(self):
):
grid.create_grid()

def test_create_grid(self):
def test_create_grid_with_coords(self):
new_grid = grid.create_grid(x=self.lon, y=self.lat)

assert np.array_equal(new_grid.lat, self.lat)
assert np.array_equal(new_grid.lon, self.lon)

def test_create_grid_with_bounds(self):
def test_create_grid_with_tuple_of_coords_and_no_bounds(self):
# This case happens if `create_axis` is used without creating bounds,
# which will return a tuple of (xr.DataArray, None).
new_grid = grid.create_grid(x=(self.lon, None), y=(self.lat, None))

assert np.array_equal(new_grid.lat, self.lat)
assert np.array_equal(new_grid.lon, self.lon)

def test_create_grid_with_tuple_of_coords_and_bounds(self):
new_grid = grid.create_grid(
x=(self.lon, self.lon_bnds), y=(self.lat, self.lat_bnds)
)
Expand Down
18 changes: 10 additions & 8 deletions xcdat/regridder/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,28 +529,30 @@ def create_grid(
axes = {"x": x, "y": y, "z": z}
ds = xr.Dataset(attrs={} if attrs is None else attrs.copy())

for key, item in axes.items():
for axis, item in axes.items():
if item is None:
continue

if isinstance(item, (tuple, list)):
if len(item) != 2:
raise ValueError(
f"Argument {key!r} should be an xr.DataArray representing "
f"Argument {axis!r} should be an xr.DataArray representing "
"coordinates or a tuple (xr.DataArray, xr.DataArray) representing "
"coordinates and bounds."
)

axis, bnds = item[0].copy(deep=True), item[1].copy(deep=True) # type: ignore[union-attr]
coords = item[0].copy(deep=True)

# ensure bnds attribute is set
axis.attrs["bounds"] = bnds.name
if item[1] is not None:
bnds = item[1].copy(deep=True)

ds = ds.assign({bnds.name: bnds})
coords.attrs["bounds"] = bnds.name

ds = ds.assign({bnds.name: bnds})
else:
axis = item.copy(deep=True)
coords = item.copy(deep=True)

ds = ds.assign_coords({axis.name: axis})
ds = ds.assign_coords({coords.name: coords})

return ds

Expand Down

0 comments on commit fd530aa

Please sign in to comment.