diff --git a/xbout/geometries.py b/xbout/geometries.py index 5d98fb17..146742eb 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -5,7 +5,7 @@ import numpy as np from .region import Region, _create_regions_toroidal -from .utils import _set_attrs_on_all_vars +from .utils import _add_attrs_to_var, _set_attrs_on_all_vars REGISTERED_GEOMETRIES = {} @@ -94,6 +94,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): # add_geometry_coords, in which case we do not need this. nx = updated_ds.dims[xcoord] updated_ds = updated_ds.assign_coords(**{xcoord: np.arange(nx)}) + _add_attrs_to_var(updated_ds, xcoord) ny = updated_ds.dims[ycoord] # dy should always be constant in x, so it is safe to slice to x=0. # [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid @@ -109,6 +110,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): # calculate ycoord at the centre of each cell y = dy.cumsum(keep_attrs=True) - dy/2. updated_ds = updated_ds.assign_coords(**{ycoord: y.values}) + _add_attrs_to_var(updated_ds, ycoord) # If full data (not just grid file) then toroidal dim will be present if zcoord in updated_ds.dims: @@ -123,6 +125,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): z = xr.DataArray(np.linspace(start=z0, stop=z1, num=nz, endpoint=False), dims=zcoord) updated_ds = updated_ds.assign_coords(**{zcoord: z}) + _add_attrs_to_var(updated_ds, zcoord) return updated_ds @@ -199,6 +202,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): "file name as the 'gridfilepath' argument to " "open_boutdataset().") ds[v] = grid[v] + _add_attrs_to_var(ds, v) # Rename 't' if user requested it ds = ds.rename(t=coordinates['t']) @@ -210,6 +214,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Make index 'x' a coordinate, useful for handling global indexing nx = ds.dims['x'] ds = ds.assign_coords(x=np.arange(nx)) + _add_attrs_to_var(ds, 'x') ny = ds.dims[coordinates['y']] # dy should always be constant in x, so it is safe to slice to x=0. # [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid @@ -295,6 +300,7 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None): "file name as the 'gridfilepath' argument to " "open_boutdataset().") ds['hthe'] = grid['hthe'] + _add_attrs_to_var(ds, 'hthe') else: hthe_from_grid = False ycoord = coordinates["y"] diff --git a/xbout/tests/test_boutdataset.py b/xbout/tests/test_boutdataset.py index 90322250..251bb117 100644 --- a/xbout/tests/test_boutdataset.py +++ b/xbout/tests/test_boutdataset.py @@ -525,16 +525,6 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry): # Load it again recovered = reload_boutdataset(savepath) - # Compare - for coord in original.coords.values(): - # Get rid of the options if they exist, because options are not dealt with - # totally consistently: they exist if a coord was created from a variable - # loaded from the BOUT++ output, but not if the coord was calculated from - # some parameters or loaded from a grid file - try: - del coord.attrs["options"] - except KeyError: - pass xrt.assert_identical(original.load(), recovered.load()) @pytest.mark.skip("saving and loading as float32 does not work") @@ -607,15 +597,6 @@ def test_reload_separate_variables( recovered = reload_boutdataset(savepath, pre_squashed=True) # Compare - for coord in original.coords.values(): - # Get rid of the options if they exist, because options are not dealt with - # totally consistently: they exist if a coord was created from a variable - # loaded from the BOUT++ output, but not if the coord was calculated from - # some parameters or loaded from a grid file - try: - del coord.attrs["options"] - except KeyError: - pass xrt.assert_identical(recovered, original) diff --git a/xbout/utils.py b/xbout/utils.py index 771a3533..59960b4e 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -1,4 +1,5 @@ from copy import deepcopy +from itertools import chain import numpy as np @@ -6,14 +7,25 @@ def _set_attrs_on_all_vars(ds, key, attr_data, copy=False): ds.attrs[key] = attr_data if copy: - for da in ds.values(): - da.attrs[key] = deepcopy(attr_data) + for v in chain(ds.data_vars, ds.coords): + ds[v].attrs[key] = deepcopy(attr_data) else: - for da in ds.values(): - da.attrs[key] = attr_data + for v in chain(ds.data_vars, ds.coords): + ds[v].attrs[key] = attr_data return ds +def _add_attrs_to_var(ds, varname, copy=False): + if copy: + for attr in ["metadata", "options", "geometry", "regions"]: + if attr in ds.attrs and attr not in ds[varname].attrs: + ds[varname].attrs[attr] = deepcopy(ds.attrs[attr]) + else: + for attr in ["metadata", "options", "geometry", "regions"]: + if attr in ds.attrs and attr not in ds[varname].attrs: + ds[varname].attrs[attr] = ds.attrs[attr] + + def _check_filetype(path): if path.suffix == '.nc': filetype = 'netcdf4'