Skip to content

Commit

Permalink
Clean up _destag_variable with respect to types and terminology
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Sep 16, 2022
1 parent bbd2669 commit a6b26a3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
9 changes: 9 additions & 0 deletions tests/test_destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def test_rename_staggered_coordinate(input_name, stagger_dim, unstag_dim_name, o
assert _rename_staggered_coordinate(input_name, stagger_dim, unstag_dim_name) == output_name


def test_destag_variable_dataarray():
with pytest.raises(ValueError):
_destag_variable(xr.DataArray(
np.zeros((2, 2)),
dims=('x_stag', 'y'),
coords={'x_stag': [0, 1], 'y': [0, 1]}
))


def test_destag_variable_missing_dim():
with pytest.raises(ValueError):
_destag_variable(xr.Variable(('x', 'y'), np.zeros((2, 2))), 'z_stag')
Expand Down
30 changes: 13 additions & 17 deletions xwrf/destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None):
xarray.Variable
The destaggered variable with renamed dimension
"""
# get the coordinate to unstagger
# option 1) user has provided the dimension
if not isinstance(datavar, xr.Variable):
# Implementation expects a Variable; don't want a DataArray or other type to slip through
raise ValueError(f'Parameter datavar must be xarray.Variable, not {type(datavar)}')

# Determine dimension to unstagger
if stagger_dim and stagger_dim not in datavar.dims:
# check that the user-passed in stag dim is actually in there
# If user provided, but not actually there, error out
raise ValueError(f'{stagger_dim} not in {datavar.dims}')

# option 2) guess the staggered dimension
elif stagger_dim is None:
# guess the name of the coordinate
# If not provided, guess based on name
stagger_dim = [x for x in datavar.dims if x.endswith('_stag')]

if len(stagger_dim) > 1:
Expand All @@ -50,26 +51,21 @@ def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None):
f'{stagger_dim}'
)

# we need a string, not a list
# we need a string, not a list containing a string
stagger_dim = stagger_dim[0]
# Otherwise, we have a valid user provided stagger dimension

# get the size of the staggereed coordinate
# Destagger by mean of offset slices representing each side with respect to the stagger_dim
stagger_dim_size = datavar.sizes[stagger_dim]

# I think the "dict(a="...")" format is preferrable... but you cant stick an fx arg string
# into that...
left_or_bottom_cells = datavar.isel({stagger_dim: slice(0, stagger_dim_size - 1)})
right_or_top_cells = datavar.isel({stagger_dim: slice(1, stagger_dim_size)})
center_mean = (left_or_bottom_cells + right_or_top_cells) * 0.5

# now change the variable name of the unstaggered coordinate
# we can pass this in if we want to, for whatever reason
# Determine new dimension name; if not given, use part of original name before "_stag"
if unstag_dim_name is None:
unstag_dim_name = stagger_dim.split('_stag')[
0
] # get the part of the name before the "_stag"
unstag_dim_name = stagger_dim.split('_stag')[0]

# return a data variable with renamed dimensions
# Return a Variable with renamed dimensions, updated data and attrs, and original encoding
return xr.Variable(
dims=tuple(str(unstag_dim_name) if dim == stagger_dim else dim for dim in center_mean.dims),
data=center_mean.data,
Expand Down

0 comments on commit a6b26a3

Please sign in to comment.