Skip to content

Commit

Permalink
More consistent attrs on coordinates
Browse files Browse the repository at this point in the history
Ensure the 'metadata', 'options', 'regions' and 'geometry' attributes are
always added to all coordinates. This ensures consistency between original
and saved-and-reloaded Datasets, allowing some workarounds in tests to
be removed.
  • Loading branch information
johnomotani committed Jul 29, 2020
1 parent f313bc1 commit d062fa9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
8 changes: 7 additions & 1 deletion xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from .region import Region, _create_regions_toroidal
from .utils import _set_attrs_on_all_vars
from .utils import _add_attrs_to_var, _set_attrs_on_all_vars

REGISTERED_GEOMETRIES = {}

Expand Down Expand Up @@ -94,6 +94,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
# add_geometry_coords, in which case we do not need this.
nx = updated_ds.dims[xcoord]
updated_ds = updated_ds.assign_coords(**{xcoord: np.arange(nx)})
_add_attrs_to_var(updated_ds, xcoord)
ny = updated_ds.dims[ycoord]
# dy should always be constant in x, so it is safe to slice to x=0.
# [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid
Expand All @@ -109,6 +110,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
# calculate ycoord at the centre of each cell
y = dy.cumsum(keep_attrs=True) - dy/2.
updated_ds = updated_ds.assign_coords(**{ycoord: y.values})
_add_attrs_to_var(updated_ds, ycoord)

# If full data (not just grid file) then toroidal dim will be present
if zcoord in updated_ds.dims:
Expand All @@ -123,6 +125,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
z = xr.DataArray(np.linspace(start=z0, stop=z1, num=nz, endpoint=False),
dims=zcoord)
updated_ds = updated_ds.assign_coords(**{zcoord: z})
_add_attrs_to_var(updated_ds, zcoord)

return updated_ds

Expand Down Expand Up @@ -199,6 +202,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
"file name as the 'gridfilepath' argument to "
"open_boutdataset().")
ds[v] = grid[v]
_add_attrs_to_var(ds, v)

# Rename 't' if user requested it
ds = ds.rename(t=coordinates['t'])
Expand All @@ -210,6 +214,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
# Make index 'x' a coordinate, useful for handling global indexing
nx = ds.dims['x']
ds = ds.assign_coords(x=np.arange(nx))
_add_attrs_to_var(ds, 'x')
ny = ds.dims[coordinates['y']]
# dy should always be constant in x, so it is safe to slice to x=0.
# [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid
Expand Down Expand Up @@ -295,6 +300,7 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None):
"file name as the 'gridfilepath' argument to "
"open_boutdataset().")
ds['hthe'] = grid['hthe']
_add_attrs_to_var(ds, 'hthe')
else:
hthe_from_grid = False
ycoord = coordinates["y"]
Expand Down
19 changes: 0 additions & 19 deletions xbout/tests/test_boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,6 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
# Load it again
recovered = reload_boutdataset(savepath)

# Compare
for coord in original.coords.values():
# Get rid of the options if they exist, because options are not dealt with
# totally consistently: they exist if a coord was created from a variable
# loaded from the BOUT++ output, but not if the coord was calculated from
# some parameters or loaded from a grid file
try:
del coord.attrs["options"]
except KeyError:
pass
xrt.assert_identical(original.load(), recovered.load())

@pytest.mark.skip("saving and loading as float32 does not work")
Expand Down Expand Up @@ -607,15 +597,6 @@ def test_reload_separate_variables(
recovered = reload_boutdataset(savepath, pre_squashed=True)

# Compare
for coord in original.coords.values():
# Get rid of the options if they exist, because options are not dealt with
# totally consistently: they exist if a coord was created from a variable
# loaded from the BOUT++ output, but not if the coord was calculated from
# some parameters or loaded from a grid file
try:
del coord.attrs["options"]
except KeyError:
pass
xrt.assert_identical(recovered, original)


Expand Down
20 changes: 16 additions & 4 deletions xbout/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
from copy import deepcopy
from itertools import chain

import numpy as np


def _set_attrs_on_all_vars(ds, key, attr_data, copy=False):
ds.attrs[key] = attr_data
if copy:
for da in ds.values():
da.attrs[key] = deepcopy(attr_data)
for v in chain(ds.data_vars, ds.coords):
ds[v].attrs[key] = deepcopy(attr_data)
else:
for da in ds.values():
da.attrs[key] = attr_data
for v in chain(ds.data_vars, ds.coords):
ds[v].attrs[key] = attr_data
return ds


def _add_attrs_to_var(ds, varname, copy=False):
if copy:
for attr in ["metadata", "options", "geometry", "regions"]:
if attr in ds.attrs and attr not in ds[varname].attrs:
ds[varname].attrs[attr] = deepcopy(ds.attrs[attr])
else:
for attr in ["metadata", "options", "geometry", "regions"]:
if attr in ds.attrs and attr not in ds[varname].attrs:
ds[varname].attrs[attr] = ds.attrs[attr]


def _check_filetype(path):
if path.suffix == '.nc':
filetype = 'netcdf4'
Expand Down

0 comments on commit d062fa9

Please sign in to comment.