diff --git a/.travis.yml b/.travis.yml index 524410f6..9df27765 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,12 +4,12 @@ python: - "3.7" env: - PIP_PACKAGES="setuptools pip pytest pytest-cov coverage codecov boutdata xarray!=0.14.0 numpy>=1.16.0" - - PIP_PACKAGES="setuptools pip pytest pytest-cov coverage codecov boutdata xarray==0.13.0 dask==1.0.0 numpy==1.16.0 natsort==5.5.0 matplotlib==3.1.1 animatplot==0.4.1 netcdf4==1.4.2 Pillow==6.1.0" # test with oldest supported version of packages. Note, using numpy==1.16.0 as a workaround for some weird fails on Travis, in principle we should work with numpy>=1.13.3. + - PIP_PACKAGES="setuptools pip pytest pytest-cov coverage codecov boutdata xarray==0.16.0 dask==2.10.0 numpy==1.16.0 natsort==5.5.0 matplotlib==3.1.1 animatplot==0.4.1 netcdf4==1.4.2 Pillow==6.1.0 fsspec" # test with oldest supported version of packages. Note, using numpy==1.16.0 as a workaround for some weird fails on Travis, in principle we should work with numpy>=1.13.3. We should not need to install fsspec explicitly, but at the moment are getting import errors in the tests due to fsspec not being present - should remove in future, probably when dask version is increased. install: - pip install --upgrade ${PIP_PACKAGES} - pip install -r requirements.txt - pip install -e . script: - - pytest -v --cov + - pytest -v --long --cov after_success: - codecov diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..ffc59902 --- /dev/null +++ b/conftest.py @@ -0,0 +1,17 @@ +import pytest + + +# Add command line option '--long' for pytest, to be used to enable long tests +def pytest_addoption(parser): + parser.addoption("--long", action="store_true", default=False, + help="enable tests marked as 'long'") + + +def pytest_collection_modifyitems(config, items): + if not config.getoption("--long"): + # --long not given in cli: skip long tests + print("\n skipping long tests, pass '--long' to enable") + skip_long = pytest.mark.skip(reason="need --long option to run") + for item in items: + if "long" in item.keywords: + item.add_marker(skip_long) diff --git a/pytest.ini b/pytest.ini index ac2cecc2..0b2d89fe 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,6 @@ filterwarnings = ignore:No geometry type found, no coordinates will be added:UserWarning ignore:deallocating CachingFileManager.*, but file is not already closed. This may indicate a bug\.:RuntimeWarning + +markers = + long: long test, or one of many permutations (disabled by default) diff --git a/requirements.txt b/requirements.txt index d293a93d..ef17102a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -xarray >= 0.13.0 -dask[array] >= 1.0.0 +xarray >= 0.16.0 +dask[array] >= 2.10.0 natsort >= 5.5.0 matplotlib >= 3.1.1 animatplot >= 0.4.1 diff --git a/setup.py b/setup.py index ec1f468b..acff2c5f 100644 --- a/setup.py +++ b/setup.py @@ -33,8 +33,8 @@ license="Apache", python_requires='>=3.6', install_requires=[ - 'xarray>=v0.13.0', - 'dask[array]>=1.0.0', + 'xarray>=0.16.0', + 'dask[array]>=2.10.0', 'natsort>=5.5.0', 'matplotlib>=3.1.1', 'animatplot>=0.4.1', diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index 3bfe4ce6..1368c630 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -13,6 +13,7 @@ from .plotting.utils import _create_norm from .region import (Region, _concat_inner_guards, _concat_outer_guards, _concat_lower_guards, _concat_upper_guards) +from .utils import _update_metadata_increased_resolution @register_dataarray_accessor('bout') @@ -48,12 +49,20 @@ def __str__(self): def to_dataset(self): """ Convert a DataArray to a Dataset, copying the attributes from the DataArray to - the Dataset. + the Dataset, and dropping attributes that only make sense for a DataArray """ da = self.data ds = da.to_dataset() - ds.attrs = da.attrs + ds.attrs = deepcopy(da.attrs) + + def dropIfExists(ds, name): + if name in ds.attrs: + del ds.attrs[name] + + dropIfExists(ds, 'direction_y') + dropIfExists(ds, 'direction_z') + dropIfExists(ds, 'cell_location') return ds @@ -117,6 +126,14 @@ def toFieldAligned(self): if self.data.direction_y != "Standard": raise ValueError("Cannot shift a " + self.direction_y + " type field to " + "field-aligned coordinates") + if hasattr(self.data, "cell_location") and not ( + self.data.cell_location == "CELL_CENTRE" + or self.data.cell_location == "CELL_ZLOW" + ): + raise ValueError( + f"toFieldAligned does not support staggered grids yet, but " + f"location is {self.data.cell_location}." + ) result = self._shiftZ(self.data['zShift']) result["direction_y"] = "Aligned" return result @@ -129,10 +146,28 @@ def fromFieldAligned(self): if self.data.direction_y != "Aligned": raise ValueError("Cannot shift a " + self.direction_y + " type field to " + "field-aligned coordinates") + if hasattr(self.data, "cell_location") and not ( + self.data.cell_location == "CELL_CENTRE" + or self.data.cell_location == "CELL_ZLOW" + ): + raise ValueError( + f"fromFieldAligned does not support staggered grids yet, but " + f"location is {self.data.cell_location}." + ) result = self._shiftZ(-self.data['zShift']) result["direction_y"] = "Standard" return result + @property + def regions(self): + if "regions" not in self.data.attrs: + raise ValueError( + "Called a method requiring regions, but these have not been created. " + "Please set the 'geometry' option when calling open_boutdataset() to " + "create regions." + ) + return self.data.attrs["regions"] + def from_region(self, name, with_guards=None): """ Get a logically-rectangular section of data from a certain region. @@ -180,6 +215,165 @@ def from_region(self, name, with_guards=None): return da + @property + def fine_interpolation_factor(self): + """ + The default factor to increase resolution when doing parallel interpolation + """ + return self.data.metadata['fine_interpolation_factor'] + + @fine_interpolation_factor.setter + def fine_interpolation_factor(self, n): + """ + Set the default factor to increase resolution when doing parallel interpolation. + + Parameters + ----------- + n : int + Factor to increase parallel resolution by + """ + self.data.metadata['fine_interpolation_factor'] = n + + def interpolate_parallel(self, region=None, *, n=None, toroidal_points=None, + method='cubic', return_dataset=False): + """ + Interpolate in the parallel direction to get a higher resolution version of the + variable. + + Parameters + ---------- + region : str, optional + By default, return a result with all regions interpolated separately and then + combined. If an explicit region argument is passed, then return the variable + from only that region. + n : int, optional + The factor to increase the resolution by. Defaults to the value set by + BoutDataset.setupParallelInterp(), or 10 if that has not been called. + toroidal_points : int or sequence of int, optional + If int, number of toroidal points to output, applies a stride to toroidal + direction to save memory usage. If sequence of int, the indexes of toroidal + points for the output. + method : str, optional + The interpolation method to use. Options from xarray.DataArray.interp(), + currently: linear, nearest, zero, slinear, quadratic, cubic. Default is + 'cubic'. + return_dataset : bool, optional + If this is set to True, return a Dataset containing this variable as a member + (by default returns a DataArray). Only used when region=None. + + Returns + ------- + A new DataArray containing a high-resolution version of the variable. (If + return_dataset=True, instead returns a Dataset containing the DataArray.) + """ + + if region is None: + # Call the single-region version of this method for each region, and combine + # the results together + parts = [ + self.interpolate_parallel(region, n=n, toroidal_points=toroidal_points, + method=method).bout.to_dataset() + for region in self.data.regions] + + # 'region' is not the same for all parts, and should not exist in the result, + # so delete before merging + for part in parts: + if 'region' in part.attrs: + del part.attrs['region'] + if 'region' in part[self.data.name].attrs: + del part[self.data.name].attrs['region'] + + result = xr.combine_by_coords(parts) + + if return_dataset: + return result + else: + # Extract the DataArray to return + return result[self.data.name] + + # Select a particular 'region' and interpolate to higher parallel resolution + da = self.data + region = da.regions[region] + tcoord = da.metadata['bout_tdim'] + xcoord = da.metadata['bout_xdim'] + ycoord = da.metadata['bout_ydim'] + zcoord = da.metadata['bout_zdim'] + + if zcoord in da.dims and da.direction_y != 'Aligned': + aligned_input = False + da = da.bout.toFieldAligned() + else: + aligned_input = True + + if n is None: + n = self.fine_interpolation_factor + + da = da.bout.from_region(region.name, with_guards={xcoord: 0, ycoord: 2}) + da = da.chunk({ycoord: None}) + + ny_fine = n*region.ny + dy = (region.yupper - region.ylower)/ny_fine + + myg = da.metadata['MYG'] + if da.metadata['keep_yboundaries'] and region.connection_lower_y is None: + ybndry_lower = myg + else: + ybndry_lower = 0 + if da.metadata['keep_yboundaries'] and region.connection_upper_y is None: + ybndry_upper = myg + else: + ybndry_upper = 0 + + y_fine = np.linspace(region.ylower - (ybndry_lower - 0.5)*dy, + region.yupper + (ybndry_upper - 0.5)*dy, + ny_fine + ybndry_lower + ybndry_upper) + + # This prevents da.interp() from being very slow. + # Apparently large attrs (i.e. regions) on a coordinate which is passed as an + # argument to dask.array.map_blocks() slow things down, maybe because coordinates + # are numpy arrays, not dask arrays? + # Slow-down was introduced in d062fa9e75c02fbfdd46e5d1104b9b12f034448f when + # _add_attrs_to_var(updated_ds, ycoord) was added in geometries.py + da[ycoord].attrs = {} + + da = da.interp({ycoord: y_fine.data}, assume_sorted=True, method=method, + kwargs={'fill_value': 'extrapolate'}) + + da = _update_metadata_increased_resolution(da, n) + + # Add dy to da as a coordinate. This will only be temporary, once we have + # combined the regions together, we will demote dy to a regular variable + dy_array = xr.DataArray(np.full([da.sizes[xcoord], da.sizes[ycoord]], dy), + dims=[xcoord, ycoord]) + # need a view of da with only x- and y-dimensions, unfortunately no neat way to + # do this with isel + da_2d = da + if tcoord in da.sizes: + da_2d = da_2d.isel(**{tcoord: 0}, drop=True) + if zcoord in da.sizes: + da_2d = da_2d.isel(**{zcoord: 0}, drop=True) + dy_array = da_2d.copy(data=dy_array) + da = da.assign_coords(dy=dy_array) + + # Remove regions which have incorrect information for the high-resolution grid. + # New regions will be generated when creating a new Dataset in + # BoutDataset.getHighParallelResVars + del da.attrs['regions'] + + if not aligned_input: + # Want output in non-aligned coordinates + da = da.bout.fromFieldAligned() + + if toroidal_points is not None and zcoord in da.sizes: + if isinstance(toroidal_points, int): + nz = len(da[zcoord]) + zstride = (nz + toroidal_points - 1)//toroidal_points + da = da.isel(**{zcoord: slice(None, None, zstride)}) + else: + da = da.isel(**{zcoord: toroidal_points}) + + return da + def animate2D(self, animate_over='t', x=None, y=None, animate=True, fps=10, save_as=None, ax=None, poloidal_plot=False, logscale=None, **kwargs): """ diff --git a/xbout/boutdataset.py b/xbout/boutdataset.py index 1fb4fe19..da2704c9 100644 --- a/xbout/boutdataset.py +++ b/xbout/boutdataset.py @@ -16,6 +16,7 @@ import numpy as np from dask.diagnostics import ProgressBar +from .geometries import apply_geometry from .plotting.animate import animate_poloidal, animate_pcolormesh, animate_line from .plotting.utils import _create_norm @@ -78,6 +79,142 @@ def getFieldAligned(self, name, caching=True): self.data[aligned_name] = self.data[name].bout.toFieldAligned() return self.data[aligned_name] + @property + def regions(self): + if "regions" not in self.data.attrs: + raise ValueError( + "Called a method requiring regions, but these have not been created. " + "Please set the 'geometry' option when calling open_boutdataset() to " + "create regions." + ) + return self.data.attrs["regions"] + + @property + def fine_interpolation_factor(self): + """ + The default factor to increase resolution when doing parallel interpolation + """ + return self.data.metadata['fine_interpolation_factor'] + + @fine_interpolation_factor.setter + def fine_interpolation_factor(self, n): + """ + Set the default factor to increase resolution when doing parallel interpolation. + + Parameters + ----------- + n : int + Factor to increase parallel resolution by + """ + ds = self.data + ds.metadata['fine_interpolation_factor'] = n + for da in ds.data_vars.values(): + da.metadata['fine_interpolation_factor'] = n + + def interpolate_parallel(self, variables, **kwargs): + """ + Interpolate in the parallel direction to get a higher resolution version of a + subset of variables. + + Note that the high-resolution variables are all loaded into memory, so most + likely it is necessary to select only a small number. The toroidal_points + argument can also be used to reduce the memory demand. + + Parameters + ---------- + variables : str or sequence of str or ... + The names of the variables to interpolate. If 'variables=...' is passed + explicitly, then interpolate all variables in the Dataset. + n : int, optional + The factor to increase the resolution by. Defaults to the value set by + BoutDataset.setupParallelInterp(), or 10 if that has not been called. + toroidal_points : int or sequence of int, optional + If int, number of toroidal points to output, applies a stride to toroidal + direction to save memory usage. If sequence of int, the indexes of toroidal + points for the output. + method : str, optional + The interpolation method to use. Options from xarray.DataArray.interp(), + currently: linear, nearest, zero, slinear, quadratic, cubic. Default is + 'cubic'. + + Returns + ------- + A new Dataset containing a high-resolution versions of the variables. The new + Dataset is a valid BoutDataset, although containing only the specified variables. + """ + + if variables is ...: + variables = [v for v in self.data] + + if isinstance(variables, str): + variables = [variables] + if isinstance(variables, tuple): + variables = list(variables) + + if 'dy' in variables: + # dy is treated specially, as it is converted to a coordinate, and then + # converted back again below, so must not call + # interpolate_parallel('dy'). + variables.remove('dy') + + # Add extra variables needed to make this a valid Dataset + if 'dx' not in variables: + variables.append('dx') + + # Need to start with a Dataset with attrs as merge() drops the attrs of the + # passed-in argument. + # Make sure the first variable has all dimensions so we don't lose any + # coordinates + def find_with_dims(first_var, dims): + if first_var is None: + dims = set(dims) + for v in variables: + if set(self.data[v].dims) == dims: + first_var = v + break + return first_var + tcoord = self.data.metadata.get("bout_tdim", "t") + zcoord = self.data.metadata.get("bout_zdim", "z") + first_var = find_with_dims(None, self.data.dims) + first_var = find_with_dims(first_var, set(self.data.dims) - set(tcoord)) + first_var = find_with_dims(first_var, set(self.data.dims) - set(zcoord)) + first_var = find_with_dims(first_var, set(self.data.dims) + - set([tcoord, zcoord])) + if first_var is None: + raise ValueError( + f"Could not find variable to interpolate with both " + f"{ds.metadata.get('bout_xdim', 'x')} and " + f"{ds.metadata.get('bout_ydim', 'y')} dimensions" + ) + variables.remove(first_var) + ds = self.data[first_var].bout.interpolate_parallel(return_dataset=True, + **kwargs) + xcoord = ds.metadata.get("bout_xdim", "x") + ycoord = ds.metadata.get("bout_ydim", "y") + for var in variables: + da = self.data[var] + if xcoord in da.dims and ycoord in da.dims: + ds = ds.merge( + da.bout.interpolate_parallel(return_dataset=True, **kwargs) + ) + elif ycoord not in da.dims: + ds[var] = da + # Can't interpolate a variable that depends on y but not x, so just skip + + # dy needs to be compatible with the new poloidal coordinate + # dy was created as a coordinate in BoutDataArray.interpolate_parallel, here just + # need to demote back to a regular variable. + ds = ds.reset_coords('dy') + + # Apply geometry + if hasattr(ds, 'geometry'): + ds = apply_geometry(ds, ds.geometry) + # if no geometry was originally applied, then ds has no geometry attribute and we + # can continue without applying geometry here + + return ds + + def save(self, savepath='./boutdata.nc', filetype='NETCDF4', variables=None, save_dtype=None, separate_vars=False, pre_load=False): """ @@ -160,7 +297,10 @@ def dict_to_attrs(obj, section): if 'regions' in to_save.attrs: # Do not need to save regions as these can be reconstructed from the metadata - del to_save.attrs['regions'] + try: + del to_save.attrs['regions'] + except KeyError: + pass for var in chain(to_save.data_vars, to_save.coords): try: del to_save[var].attrs['regions'] diff --git a/xbout/geometries.py b/xbout/geometries.py index 6a035350..9baa4506 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -5,7 +5,7 @@ import numpy as np from .region import Region, _create_regions_toroidal -from .utils import _set_attrs_on_all_vars +from .utils import _add_attrs_to_var, _set_attrs_on_all_vars REGISTERED_GEOMETRIES = {} @@ -68,6 +68,65 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): else: updated_ds = add_geometry_coords(ds) + # Add global 1D coordinates + # ###################### + # Note the global coordinates used here are defined so that they are zero at + # the boundaries of the grid (where the grid includes all boundary cells), not + # necessarily the physical boundaries, because constant offsets do not matter, as + # long as these bounds are consistent with the global coordinates defined in + # Region.__init__() (we will only use these coordinates for interpolation) and it is + # simplest to calculate them with cumsum(). + tcoord = updated_ds.metadata.get('bout_tdim', 't') + xcoord = updated_ds.metadata.get('bout_xdim', 'x') + ycoord = updated_ds.metadata.get('bout_ydim', 'y') + zcoord = updated_ds.metadata.get('bout_zdim', 'z') + + if (tcoord not in ds.coords) and (tcoord in ds.dims): + # Create the time coordinate from t_array + updated_ds = updated_ds.rename({'t_array': tcoord}) + updated_ds = updated_ds.set_coords(tcoord) + + if xcoord not in updated_ds.coords: + # Make index 'x' a coordinate, useful for handling global indexing + # Note we have to use the index value, not the value calculated from 'dx' because + # 'dx' may not be consistent between different regions (e.g. core and PFR). + # For some geometries xcoord may have already been created by + # add_geometry_coords, in which case we do not need this. + nx = updated_ds.dims[xcoord] + updated_ds = updated_ds.assign_coords(**{xcoord: np.arange(nx)}) + _add_attrs_to_var(updated_ds, xcoord) + ny = updated_ds.dims[ycoord] + # dy should always be constant in x, so it is safe to slice to x=0. + # [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid + # (similarly x-coordinate is 1d coordinate that labels y-z slices and + # z-coordinate is a 1d coordinate that labels x-y slices). A coordinate might + # have different values in disconnected regions, but there are no branch-cuts + # allowed in the x-direction in BOUT++ (at least for the momement), so the + # y-coordinate has to be 1d and single-valued. Therefore similarly dy has to be + # 1d and single-valued.] Need drop=True so that the result does not have an + # x-coordinate value which prevents it being added as a coordinate. + dy = updated_ds['dy'].isel({xcoord: 0}, drop=True) + + # calculate ycoord at the centre of each cell + y = dy.cumsum(keep_attrs=True) - dy/2. + updated_ds = updated_ds.assign_coords(**{ycoord: y.values}) + _add_attrs_to_var(updated_ds, ycoord) + + # If full data (not just grid file) then toroidal dim will be present + if zcoord in updated_ds.dims: + nz = updated_ds.dims[zcoord] + z0 = 2*np.pi*updated_ds.metadata['ZMIN'] + z1 = z0 + nz*updated_ds.metadata['dz'] + if not np.isclose(z1, 2.*np.pi*updated_ds.metadata['ZMAX'], + rtol=1.e-15, atol=0.): + warn(f"Size of toroidal domain as calculated from nz*dz ({str(z1 - z0)}" + f" is not the same as 2pi*(ZMAX - ZMIN) (" + f"{2.*np.pi*updated_ds.metadata['ZMAX'] - z0}): using value from dz") + z = xr.DataArray(np.linspace(start=z0, stop=z1, num=nz, endpoint=False), + dims=zcoord) + updated_ds = updated_ds.assign_coords(**{zcoord: z}) + _add_attrs_to_var(updated_ds, zcoord) + return updated_ds @@ -121,29 +180,29 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): coordinates = _set_default_toroidal_coordinates(coordinates) + if set(coordinates.values()).issubset(set(ds.coords).union(ds.dims)): + # Loading a Dataset which already had the coordinates created for it + ds = _create_regions_toroidal(ds) + return ds + # Check whether coordinates names conflict with variables in ds - bad_names = [name for name in coordinates.values() if name in ds.data_vars] + bad_names = [name for name in coordinates.values() if name in ds and name not in ds.coords] if bad_names: raise ValueError("Coordinate names {} clash with variables in the dataset. " "Register a different geometry to provide alternative names. " "It may be useful to use the 'coordinates' argument to " "add_toroidal_geometry_coords() for this.".format(bad_names)) - if set(coordinates.values()).issubset(set(ds.coords).union(ds.dims)): - # Loading a Dataset which already had the coordinates created for it - ds = _create_regions_toroidal(ds) - return ds - # Get extra geometry information from grid file if it's not in the dump files needed_variables = ['psixy', 'Rxy', 'Zxy'] for v in needed_variables: if v not in ds: if grid is None: - raise ValueError( - f"Grid file is required to provide {v}. Pass the grid file name as " - f"the 'gridfilepath' argument to open_boutdataset()." - ) + raise ValueError("Grid file is required to provide %s. Pass the grid " + "file name as the 'gridfilepath' argument to " + "open_boutdataset().") ds[v] = grid[v] + _add_attrs_to_var(ds, v) # Rename 't' if user requested it ds = ds.rename(t=coordinates['t']) @@ -151,26 +210,6 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Change names of dimensions to Orthogonal Toroidal ones ds = ds.rename(y=coordinates['y']) - # Add 1D Orthogonal Toroidal coordinates - # Make index 'x' a coordinate, useful for handling global indexing - nx = ds.dims['x'] - ds = ds.assign_coords(x=np.arange(nx)) - ny = ds.dims[coordinates['y']] - # dy should always be constant in x, so it is safe to slice to x=0. - # [The y-coordinate has to be a 1d coordinate that labels x-z slices of the grid - # (similarly x-coordinate is 1d coordinate that labels y-z slices and z-coordinate is - # a 1d coordinate that labels x-y slices). A coordinate might have different values - # in disconnected regions, but there are no branch-cuts allowed in the x-direction in - # BOUT++ (at least for the momement), so the y-coordinate has to be 1d and - # single-valued. Therefore similarly dy has to be 1d and single-valued.] - # Need drop=True so that the result does not have an x-coordinate value which - # prevents it being added as a coordinate. - dy = ds['dy'].isel(x=0, drop=True) - - # calculate theta at the centre of each cell - theta = dy.cumsum(keep_attrs=True) - dy/2. - ds = ds.assign_coords(**{coordinates['y']: theta}) - # TODO automatically make this coordinate 1D in simplified cases? ds = ds.rename(psixy=coordinates['x']) ds = ds.set_coords(coordinates['x']) @@ -185,16 +224,6 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # If full data (not just grid file) then toroidal dim will be present if 'z' in ds.dims: ds = ds.rename(z=coordinates['z']) - nz = ds.dims[coordinates['z']] - phi0 = 2*np.pi*ds.metadata['ZMIN'] - phi1 = phi0 + nz*ds.metadata['dz'] - if not np.isclose(phi1, 2.*np.pi*ds.metadata['ZMAX'], rtol=1.e-15, atol=0.): - warn(f"Size of toroidal domain as calculated from nz*dz ({phi1 - phi0}) is " - f"not the same as 2pi*(ZMAX - ZMIN) " - f"({2.*np.pi*ds.metadata['ZMAX'] - phi0}): using value from dz") - phi = xr.DataArray(np.linspace(start=phi0, stop=phi1, num=nz, endpoint=False), - dims=coordinates['z']) - ds = ds.assign_coords(**{coordinates['z']: phi}) # Record which dimension 'z' was renamed to. ds.metadata['bout_zdim'] = coordinates['z'] @@ -209,7 +238,19 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Add zShift as a coordinate, so that it gets interpolated along with a variable try: ds = ds.set_coords('zShift') - except KeyError: + except ValueError: + pass + try: + ds = ds.set_coords('zShift_CELL_XLOW') + except ValueError: + pass + try: + ds = ds.set_coords('zShift_CELL_YLOW') + except ValueError: + pass + try: + ds = ds.set_coords('zShift_CELL_ZLOW') + except ValueError: pass ds = _create_regions_toroidal(ds) @@ -227,30 +268,33 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None): ds = _create_regions_toroidal(ds) return ds + + ds = add_toroidal_geometry_coords(ds, coordinates=coordinates, grid=grid) + # Get extra geometry information from grid file if it's not in the dump files # Add 'hthe' from grid file, needed below for radial coordinate if 'hthe' not in ds: hthe_from_grid = True + ycoord = "y" if grid is None: raise ValueError("Grid file is required to provide %s. Pass the grid " "file name as the 'gridfilepath' argument to " "open_boutdataset().") ds['hthe'] = grid['hthe'] + _add_attrs_to_var(ds, 'hthe') else: hthe_from_grid = False - - ds = add_toroidal_geometry_coords(ds, coordinates=coordinates, grid=grid) + ycoord = coordinates["y"] # Add 1D radial coordinate if 'r' in ds: raise ValueError("Cannot have variable 'r' in dataset when using " "geometry='s-alpha'") - ds['r'] = ds['hthe'].isel({coordinates['y']: 0}).squeeze(drop=True) + ds['r'] = ds['hthe'].isel({ycoord: 0}).squeeze(drop=True) ds['r'].attrs['units'] = 'm' - # remove x-index coordinate, don't need when we have 'r' as a radial coordinate - ds = ds.drop('x') ds = ds.set_coords('r') ds = ds.rename(x='r') + ds.metadata['bout_xdim'] = 'r' if hthe_from_grid: # remove hthe because it does not have correct metadata diff --git a/xbout/load.py b/xbout/load.py index 30d40c1c..9c913f70 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -164,6 +164,10 @@ def open_boutdataset(datapath='./BOUT.dmp.*.nc', inputfilepath=None, if run_name: ds.name = run_name + # Set some default settings that are only used in post-processing by xBOUT, not by + # BOUT++ + ds.bout.fine_interpolation_factor = 8 + if info == 'terse': print("Read in dataset from {}".format(str(Path(datapath)))) elif info: @@ -668,9 +672,6 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): gridfilepath = Path(datapath) grid = xr.open_dataset(gridfilepath, engine=_check_filetype(gridfilepath)) - if 'z' in grid_chunks and 'z' not in grid.dims: - del grid_chunks['z'] - grid = grid.chunk(grid_chunks) # TODO find out what 'yup_xsplit' etc are in the doublenull storm file John gave me # For now drop any variables with extra dimensions @@ -707,4 +708,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): grid = xr.concat((grid_lower, grid_upper), dim='y', data_vars='minimal', compat='identical', join='exact') + + if 'z' in grid_chunks and 'z' not in grid.dims: + del grid_chunks['z'] + grid = grid.chunk(grid_chunks) + return grid diff --git a/xbout/plotting/animate.py b/xbout/plotting/animate.py index c1d7a7ac..e2149eaa 100644 --- a/xbout/plotting/animate.py +++ b/xbout/plotting/animate.py @@ -121,10 +121,10 @@ def animate_poloidal(da, *, ax=None, cax=None, animate_over='t', separatrix=True targets = False if separatrix: - plot_separatrices(da_regions, ax) + plot_separatrices(da_regions, ax, x=x, y=y) if targets: - plot_targets(da_regions, ax, hatching=add_limiter_hatching) + plot_targets(da_regions, ax, x=x, y=y, hatching=add_limiter_hatching) if animate: timeline = amp.Timeline(np.arange(da.sizes[animate_over]), fps=fps) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 7b293bb0..71687c52 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -213,12 +213,20 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, raise ValueError('Argument passed to gridlines must be bool, int or ' 'slice. Got a ' + type(value) + ', ' + str(value)) - R_regions = [da_region['R'] for da_region in da_regions.values()] - Z_regions = [da_region['Z'] for da_region in da_regions.values()] - - for R, Z in zip(R_regions, Z_regions): - if (not da.metadata['bout_xdim'] in R.dims - and not da.metadata['bout_ydim'] in R.dims): + x_regions = [da_region[x] for da_region in da_regions.values()] + y_regions = [da_region[y] for da_region in da_regions.values()] + + for x, y in zip(x_regions, y_regions): + if ( + ( + not da.metadata['bout_xdim'] in x.dims + and not da.metadata['bout_ydim'] in x.dims + ) + or ( + not da.metadata['bout_xdim'] in y.dims + and not da.metadata['bout_ydim'] in y.dims + ) + ): # Small regions around X-point do not have segments in x- or y-directions, # so skip # Currently this region does not exist, but there is a small white gap at @@ -229,16 +237,16 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, # form dim_order = (da.metadata['bout_xdim'], da.metadata['bout_ydim']) yarg = {da.metadata['bout_ydim']: gridlines['x']} - plt.plot(R.isel(**yarg).transpose(*dim_order, transpose_coords=True), - Z.isel(**yarg).transpose(*dim_order, transpose_coords=True), + plt.plot(x.isel(**yarg).transpose(*dim_order, transpose_coords=True), + y.isel(**yarg).transpose(*dim_order, transpose_coords=True), color='k', lw=0.1) if gridlines.get('y') is not None: xarg = {da.metadata['bout_xdim']: gridlines['y']} # Need to plot transposed arrays to make gridlines that go in the # y-direction dim_order = (da.metadata['bout_ydim'], da.metadata['bout_xdim']) - plt.plot(R.isel(**xarg).transpose(*dim_order, transpose_coords=True), - Z.isel(**yarg).transpose(*dim_order, transpose_coords=True), + plt.plot(x.isel(**xarg).transpose(*dim_order, transpose_coords=True), + y.isel(**yarg).transpose(*dim_order, transpose_coords=True), color='k', lw=0.1) ax.set_title(da.name) @@ -248,9 +256,9 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, targets = False if separatrix: - plot_separatrices(da_regions, ax) + plot_separatrices(da_regions, ax, x=x, y=y) if targets: - plot_targets(da_regions, ax, hatching=add_limiter_hatching) + plot_targets(da_regions, ax, x=x, y=y, hatching=add_limiter_hatching) return artists diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index 66c2e956..af017986 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -73,7 +73,7 @@ def _is_core_only(da): return (ix1 >= nx and ix2 >= nx) -def plot_separatrices(da, ax): +def plot_separatrices(da, ax, *, x='R', y='Z'): """Plot separatrices""" if not isinstance(da, dict): @@ -90,14 +90,14 @@ def plot_separatrices(da, ax): inner = da_region.region.connection_inner_x if inner is not None: da_inner = da_regions[inner] - R = 0.5*(da_inner['R'].isel(**{xcoord: -1}) - + da_region['R'].isel(**{xcoord: 0})) - Z = 0.5*(da_inner['Z'].isel(**{xcoord: -1}) - + da_region['Z'].isel(**{xcoord: 0})) - ax.plot(R, Z, 'k--') + x_sep = 0.5*(da_inner[x].isel(**{xcoord: -1}) + + da_region[x].isel(**{xcoord: 0})) + y_sep = 0.5*(da_inner[y].isel(**{xcoord: -1}) + + da_region[y].isel(**{xcoord: 0})) + ax.plot(x_sep, y_sep, 'k--') -def plot_targets(da, ax, hatching=True): +def plot_targets(da, ax, *, x='R', y='Z', hatching=True): """Plot divertor and limiter target plates""" if not isinstance(da, dict): @@ -118,16 +118,16 @@ def plot_targets(da, ax, hatching=True): for da_region in da_regions.values(): if da_region.region.connection_lower_y is None: # lower target exists - R = da_region.coords['R'].isel(**{ycoord: y_boundary_guards}) - Z = da_region.coords['Z'].isel(**{ycoord: y_boundary_guards}) - [line] = ax.plot(R, Z, 'k-', linewidth=2) + x_target = da_region.coords[x].isel(**{ycoord: y_boundary_guards}) + y_target = da_region.coords[y].isel(**{ycoord: y_boundary_guards}) + [line] = ax.plot(x_target, y_target, 'k-', linewidth=2) if hatching: _add_hatching(line, ax) if da_region.region.connection_upper_y is None: # upper target exists - R = da_region.coords['R'].isel(**{ycoord: -y_boundary_guards - 1}) - Z = da_region.coords['Z'].isel(**{ycoord: -y_boundary_guards - 1}) - [line] = ax.plot(R, Z, 'k-', linewidth=2) + x_target = da_region.coords[x].isel(**{ycoord: -y_boundary_guards - 1}) + y_target = da_region.coords[y].isel(**{ycoord: -y_boundary_guards - 1}) + [line] = ax.plot(x_target, y_target, 'k-', linewidth=2) if hatching: _add_hatching(line, ax, reversed=True) diff --git a/xbout/region.py b/xbout/region.py index b4450966..2fe4b134 100644 --- a/xbout/region.py +++ b/xbout/region.py @@ -1,8 +1,4 @@ -from collections import OrderedDict - -import numpy as np import xarray as xr - from .utils import _set_attrs_on_all_vars @@ -56,23 +52,65 @@ def __init__(self, *, name, ds=None, xinner_ind=None, xouter_ind=None, self.connection_upper_y = connection_upper_y if ds is not None: + # self.nx, self.ny should not include boundary points. + # self.xinner, self.xouter, self.ylower, self.yupper + if ds.metadata['keep_xboundaries']: + xbndry = ds.metadata['MXG'] + if self.connection_inner_x is None: + self.nx -= xbndry + + # used to calculate x-coordinate of inner side (self.xinner) + xinner_ind += xbndry + + if self.connection_outer_x is None: + self.nx -= xbndry + + # used to calculate x-coordinate of outer side (self.xouter) + xouter_ind -= xbndry + + if ds.metadata['keep_yboundaries']: + ybndry = ds.metadata['MYG'] + if self.connection_lower_y is None: + self.ny -= ybndry + + # used to calculate y-coordinate of lower side (self.ylower) + ylower_ind += ybndry + + if self.connection_upper_y is None: + self.ny -= ybndry + + # used to calculate y-coordinate of upper side (self.yupper) + yupper_ind -= ybndry + # calculate start and end coordinates ##################################### self.xcoord = ds.metadata['bout_xdim'] self.ycoord = ds.metadata['bout_ydim'] + # Note the global coordinates used here are defined so that they are zero at + # the boundaries of the grid (where the grid includes all boundary cells), + # not necessarily the physical boundaries because constant offsets do not + # matter, as long as these bounds are consistent with the global coordinates + # defined in apply_geometry (we will only use these coordinates for + # interpolation) and it is simplest to calculate them with cumsum(). + # dx is constant in any particular region in the y-direction, so convert to a # 1d array - dx = ds['dx'].isel(**{self.ycoord: self.ylower_ind}) + # Note that this is not the same coordinate as the 'x' coordinate that is + # created by default from the x-index, as these values are set only for + # particular regions, so do not need to be consistent between different + # regions (e.g. core and PFR), so we are not forced to use just the index + # value here. + dx = ds['dx'].isel({self.ycoord: ylower_ind}) dx_cumsum = dx.cumsum() - self.xinner = dx_cumsum[xinner_ind] - dx[xinner_ind]/2. - self.xouter = dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1]/2. + self.xinner = dx_cumsum[xinner_ind] - dx[xinner_ind] + self.xouter = dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1] # dy is constant in the x-direction, so convert to a 1d array dy = ds['dy'].isel(**{self.xcoord: self.xinner_ind}) dy_cumsum = dy.cumsum() - self.ylower = dy_cumsum[ylower_ind] - dy[ylower_ind]/2. - self.yupper = dy_cumsum[yupper_ind - 1] + dy[yupper_ind - 1]/2. + self.ylower = dy_cumsum[ylower_ind] - dy[ylower_ind] + self.yupper = dy_cumsum[yupper_ind - 1] def __repr__(self): result = "\n" @@ -216,7 +254,7 @@ def _order_vars(lower, upper): def _get_topology(ds): jys11 = ds.metadata['jyseps1_1'] jys21 = ds.metadata['jyseps2_1'] - nyinner = ds.metadata['ny_inner'] + ny_inner = ds.metadata['ny_inner'] jys12 = ds.metadata['jyseps1_2'] jys22 = ds.metadata['jyseps2_2'] ny = ds.metadata['ny'] @@ -237,13 +275,13 @@ def _get_topology(ds): return 'single-null' if jys11 == jys21 and jys12 == jys22: - if jys11 < nyinner - 1 and jys22 > nyinner: + if jys11 < ny_inner - 1 and jys22 > ny_inner: return 'xpoint' else: raise ValueError('Currently unsupported topology') if ixs1 == ixs2: - if jys21 < nyinner - 1 and jys12 > nyinner: + if jys21 < ny_inner - 1 and jys12 > ny_inner: return 'connected-double-null' else: raise ValueError('Currently unsupported topology') @@ -251,31 +289,315 @@ def _get_topology(ds): return 'disconnected-double-null' -def _create_connection_x(regions, inner, outer): - regions[inner].connection_outer_x = outer - regions[outer].connection_inner_x = inner - - -def _create_connection_y(regions, lower, upper): - regions[lower].connection_upper_y = upper - regions[upper].connection_lower_y = lower +def _check_connections(regions): + for region in regions.values(): + if region.connection_inner_x is not None: + if regions[region.connection_inner_x].connection_outer_x != region.name: + raise ValueError( + f'Inner-x connection of {region.name} is ' + f'{region.connection_inner_x}, but outer-x connection of ' + f'{region.connection_inner_x} is ' + f'{regions[region.connection_inner_x].connection_outer_x}') + if region.connection_outer_x is not None: + if regions[region.connection_outer_x].connection_inner_x != region.name: + raise ValueError( + f'Inner-x connection of {region.name} is ' + f'{region.connection_outer_x}, but inner-x connection of ' + f'{region.connection_outer_x} is ' + f'{regions[region.connection_outer_x].connection_inner_x}') + if region.connection_lower_y is not None: + if regions[region.connection_lower_y].connection_upper_y != region.name: + raise ValueError( + f'Lower-y connection of {region.name} is ' + f'{region.connection_lower_y}, but upper-y connection of ' + f'{region.connection_lower_y} is ' + f'{regions[region.connection_lower_y].connection_upper_y}') + if region.connection_upper_y is not None: + if regions[region.connection_upper_y].connection_lower_y != region.name: + raise ValueError( + f'Upper-y connection of {region.name} is ' + f'{region.connection_upper_y}, but lower-y connection of ' + f'{region.connection_upper_y} is ' + f'{regions[region.connection_upper_y].connection_lower_y}') + + +topologies = {} + + +def topology_disconnected_double_null(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, + jys12, jys22, ny, ybndry): + regions = {} + regions['lower_inner_PFR'] = Region( + name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=0, yupper_ind=jys11 + 1, + connection_outer_x='lower_inner_intersep', + connection_upper_y='lower_outer_PFR') + regions['lower_inner_intersep'] = Region( + name='lower_inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=0, yupper_ind=jys11 + 1, connection_inner_x='lower_inner_PFR', + connection_outer_x='lower_inner_SOL', connection_upper_y='inner_intersep') + regions['lower_inner_SOL'] = Region( + name='lower_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=0, yupper_ind=jys11 + 1, + connection_inner_x='lower_inner_intersep', connection_upper_y='inner_SOL') + regions['inner_core'] = Region( + name='inner_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys11 + 1, yupper_ind=jys21 + 1, + connection_outer_x='inner_intersep', connection_lower_y='outer_core', + connection_upper_y='outer_core') + regions['inner_intersep'] = Region( + name='inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=jys11 + 1, yupper_ind=jys21 + 1, connection_inner_x='inner_core', + connection_outer_x='inner_SOL', connection_lower_y='lower_inner_intersep', + connection_upper_y='outer_intersep') + regions['inner_SOL'] = Region( + name='inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys11 + 1, yupper_ind=jys21 + 1, + connection_inner_x='inner_intersep', connection_lower_y='lower_inner_SOL', + connection_upper_y='upper_inner_SOL') + regions['upper_inner_PFR'] = Region( + name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys21 + 1, yupper_ind=ny_inner, + connection_outer_x='upper_inner_intersep', + connection_lower_y='upper_outer_PFR') + regions['upper_inner_intersep'] = Region( + name='upper_inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=jys21 + 1, yupper_ind=ny_inner, + connection_inner_x='upper_inner_PFR', connection_outer_x='upper_inner_SOL', + connection_lower_y='upper_outer_intersep') + regions['upper_inner_SOL'] = Region( + name='upper_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys21 + 1, yupper_ind=ny_inner, + connection_inner_x='upper_inner_intersep', connection_lower_y='inner_SOL') + regions['upper_outer_PFR'] = Region( + name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=ny_inner, yupper_ind=jys12 + 1, + connection_outer_x='upper_outer_intersep', + connection_upper_y='upper_inner_PFR') + regions['upper_outer_intersep'] = Region( + name='upper_outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=ny_inner, yupper_ind=jys12 + 1, + connection_inner_x='upper_outer_PFR', connection_outer_x='upper_outer_SOL', + connection_upper_y='upper_inner_intersep') + regions['upper_outer_SOL'] = Region( + name='upper_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=ny_inner, yupper_ind=jys12 + 1, + connection_inner_x='upper_outer_intersep', connection_upper_y='outer_SOL') + regions['outer_core'] = Region( + name='outer_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys12 + 1, yupper_ind=jys22 + 1, + connection_outer_x='outer_intersep', connection_lower_y='inner_core', + connection_upper_y='inner_core') + regions['outer_intersep'] = Region( + name='outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=jys12 + 1, yupper_ind=jys22 + 1, connection_inner_x='outer_core', + connection_outer_x='outer_SOL', connection_lower_y='inner_intersep', + connection_upper_y='lower_outer_intersep') + regions['outer_SOL'] = Region( + name='outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys12 + 1, yupper_ind=jys22 + 1, + connection_inner_x='outer_intersep', connection_lower_y='upper_outer_SOL', + connection_upper_y='lower_outer_SOL') + regions['lower_outer_PFR'] = Region( + name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys22 + 1, yupper_ind=ny, + connection_outer_x='lower_outer_intersep', + connection_lower_y='lower_inner_PFR') + regions['lower_outer_intersep'] = Region( + name='lower_outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_inner_x='lower_outer_PFR', + connection_outer_x='lower_outer_SOL', connection_lower_y='outer_intersep') + regions['lower_outer_SOL'] = Region( + name='lower_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys22 + 1, yupper_ind=ny, + connection_inner_x='lower_outer_intersep', connection_lower_y='outer_SOL') + return regions + + +topologies['disconnected-double-null'] = topology_disconnected_double_null + + +def topology_connected_double_null(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, + jys22, ny, ybndry): + regions = {} + regions['lower_inner_PFR'] = Region( + name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=0, yupper_ind=jys11 + 1, connection_outer_x='lower_inner_SOL', + connection_upper_y='lower_outer_PFR') + regions['lower_inner_SOL'] = Region( + name='lower_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=0, yupper_ind=jys11 + 1, connection_inner_x='lower_inner_PFR', + connection_upper_y='inner_SOL') + regions['inner_core'] = Region( + name='inner_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys11 + 1, yupper_ind=jys21 + 1, connection_outer_x='inner_SOL', + connection_lower_y='outer_core', connection_upper_y='outer_core') + regions['inner_SOL'] = Region( + name='inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys11 + 1, yupper_ind=jys21 + 1, connection_inner_x='inner_core', + connection_lower_y='lower_inner_SOL', connection_upper_y='upper_inner_SOL') + regions['upper_inner_PFR'] = Region( + name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys21 + 1, yupper_ind=ny_inner, + connection_outer_x='upper_inner_SOL', connection_lower_y='upper_outer_PFR') + regions['upper_inner_SOL'] = Region( + name='upper_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys21 + 1, yupper_ind=ny_inner, + connection_inner_x='upper_inner_PFR', connection_lower_y='inner_SOL') + regions['upper_outer_PFR'] = Region( + name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=ny_inner, yupper_ind=jys12 + 1, + connection_outer_x='upper_outer_SOL', connection_upper_y='upper_inner_PFR') + regions['upper_outer_SOL'] = Region( + name='upper_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=ny_inner, yupper_ind=jys12 + 1, + connection_inner_x='upper_outer_PFR', connection_upper_y='outer_SOL') + regions['outer_core'] = Region( + name='outer_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys12 + 1, yupper_ind=jys22 + 1, connection_outer_x='outer_SOL', + connection_lower_y='inner_core', connection_upper_y='inner_core') + regions['outer_SOL'] = Region( + name='outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys12 + 1, yupper_ind=jys22 + 1, connection_inner_x='outer_core', + connection_lower_y='upper_outer_SOL', connection_upper_y='lower_outer_SOL') + regions['lower_outer_PFR'] = Region( + name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_outer_x='lower_outer_SOL', + connection_lower_y='lower_inner_PFR') + regions['lower_outer_SOL'] = Region( + name='lower_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_inner_x='lower_outer_PFR', + connection_lower_y='outer_SOL') + return regions + + +topologies['connected-double-null'] = topology_connected_double_null + + +def topology_single_null(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, jys22, + ny, ybndry): + regions = {} + regions['inner_PFR'] = Region( + name='inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=0, + yupper_ind=jys11 + 1, connection_outer_x='inner_SOL', + connection_upper_y='outer_PFR') + regions['inner_SOL'] = Region( + name='inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=0, + yupper_ind=jys11 + 1, connection_inner_x='inner_PFR', + connection_upper_y='SOL') + regions['core'] = Region( + name='core', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=jys11 + 1, + yupper_ind=jys22 + 1, connection_outer_x='SOL', connection_lower_y='core', + connection_upper_y='core') + regions['SOL'] = Region( + name='SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=jys11 + 1, + yupper_ind=jys22 + 1, connection_inner_x='core', + connection_lower_y='inner_SOL', connection_upper_y='outer_SOL') + regions['outer_PFR'] = Region( + name='outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_outer_x='outer_SOL', + connection_lower_y='inner_PFR') + regions['outer_SOL'] = Region( + name='outer_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_inner_x='outer_PFR', + connection_lower_y='SOL') + return regions + + +topologies['single-null'] = topology_single_null + + +def topology_limiter(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, jys22, ny, + ybndry): + regions = {} + regions['core'] = Region( + name='core', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=ybndry, + yupper_ind=ny - ybndry, connection_outer_x='SOL', connection_lower_y='core', + connection_upper_y='core') + regions['SOL'] = Region( + name='SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=0, + yupper_ind=ny, connection_inner_x='core') + return regions + + +topologies['limiter'] = topology_limiter + + +def topology_core(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, jys22, ny, + ybndry): + regions = {} + regions['core'] = Region( + name='core', ds=ds, xinner_ind=0, xouter_ind=nx, ylower_ind=ybndry, + yupper_ind=ny - ybndry, connection_lower_y='core', connection_upper_y='core') + return regions + + +topologies['core'] = topology_core + + +def topology_sol(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, jys22, ny, + ybndry): + regions = {} + regions['SOL'] = Region( + name='SOL', ds=ds, xinner_ind=0, xouter_ind=nx, ylower_ind=0, + yupper_ind=ny) + return regions + + +topologies['sol'] = topology_sol + + +def topology_xpoint(*, ds, ixs1, ixs2, nx, jys11, jys21, ny_inner, jys12, jys22, ny, + ybndry): + regions = {} + regions['lower_inner_PFR'] = Region( + name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=0, yupper_ind=jys11 + 1, connection_outer_x='lower_inner_SOL', + connection_upper_y='lower_outer_PFR') + regions['lower_inner_SOL'] = Region( + name='lower_inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, + ylower_ind=0, yupper_ind=jys11 + 1, connection_inner_x='lower_inner_PFR', + connection_upper_y='upper_inner_SOL') + regions['upper_inner_PFR'] = Region( + name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys11 + 1, yupper_ind=ny_inner, + connection_outer_x='upper_inner_SOL', connection_lower_y='upper_outer_PFR') + regions['upper_inner_SOL'] = Region( + name='upper_inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, + ylower_ind=jys11 + 1, yupper_ind=ny_inner, + connection_inner_x='upper_inner_PFR', connection_lower_y='lower_inner_SOL') + regions['upper_outer_PFR'] = Region( + name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=ny_inner, yupper_ind=jys22 + 1, + connection_outer_x='upper_outer_SOL', connection_upper_y='upper_inner_PFR') + regions['upper_outer_SOL'] = Region( + name='upper_outer_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, + ylower_ind=ny_inner, yupper_ind=jys22 + 1, + connection_inner_x='upper_outer_PFR', connection_upper_y='lower_outer_SOL') + regions['lower_outer_PFR'] = Region( + name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_outer_x='lower_outer_SOL', + connection_lower_y='lower_inner_PFR') + regions['lower_outer_SOL'] = Region( + name='lower_outer_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, + ylower_ind=jys22 + 1, yupper_ind=ny, connection_inner_x='lower_outer_PFR', + connection_lower_y='upper_outer_SOL') + return regions + + +topologies['xpoint'] = topology_xpoint def _create_regions_toroidal(ds): topology = _get_topology(ds) - coordinates = {'t': ds.metadata.get('bout_tdim', None), - 'x': ds.metadata.get('bout_xdim', None), - 'y': ds.metadata.get('bout_ydim', None), - 'z': ds.metadata.get('bout_zdim', None)} - ixs1 = ds.metadata['ixseps1'] ixs2 = ds.metadata['ixseps2'] nx = ds.metadata['nx'] jys11 = ds.metadata['jyseps1_1'] jys21 = ds.metadata['jyseps2_1'] - nyinner = ds.metadata['ny_inner'] + ny_inner = ds.metadata['ny_inner'] jys12 = ds.metadata['jyseps1_2'] jys22 = ds.metadata['jyseps2_2'] ny = ds.metadata['ny'] @@ -298,7 +620,7 @@ def _create_regions_toroidal(ds): jys21 = _in_range(jys21, 0, ny - 1) jys12 = _in_range(jys12, 0, ny - 1) jys21, jys12 = _order_vars(jys21, jys12) - nyinner = _in_range(nyinner, jys21 + 1, jys12 + 1) + ny_inner = _in_range(ny_inner, jys21 + 1, jys12 + 1) jys22 = _in_range(jys22, 0, ny - 1) # Adjust for boundary cells @@ -309,222 +631,20 @@ def _create_regions_toroidal(ds): nx -= 2*mxg jys11 += ybndry jys21 += ybndry - nyinner += ybndry + ybndry_upper + ny_inner += ybndry + ybndry_upper jys12 += ybndry + 2*ybndry_upper jys22 += ybndry + 2*ybndry_upper ny += 2*ybndry + 2*ybndry_upper # Note, include guard cells in the created regions, fill them later - regions = OrderedDict() - if topology == 'disconnected-double-null': - regions['lower_inner_PFR'] = Region( - name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['lower_inner_intersep'] = Region( - name='lower_inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['lower_inner_SOL'] = Region( - name='lower_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['inner_core'] = Region( - name='inner_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys11 + 1, yupper_ind=jys21 + 1) - regions['inner_intersep'] = Region( - name='inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=jys11 + 1, yupper_ind=jys21 + 1) - regions['inner_SOL'] = Region( - name='inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys11 + 1, yupper_ind=jys21 + 1) - regions['upper_inner_PFR'] = Region( - name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys21 + 1, yupper_ind=nyinner) - regions['upper_inner_intersep'] = Region( - name='upper_inner_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=jys21 + 1, yupper_ind=nyinner) - regions['upper_inner_SOL'] = Region( - name='upper_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys21 + 1, yupper_ind=nyinner) - regions['upper_outer_PFR'] = Region( - name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=nyinner, yupper_ind=jys12 + 1) - regions['upper_outer_intersep'] = Region( - name='upper_outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=nyinner, yupper_ind=jys12 + 1) - regions['upper_outer_SOL'] = Region( - name='upper_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=nyinner, yupper_ind=jys12 + 1) - regions['outer_core'] = Region( - name='outer_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys12 + 1, yupper_ind=jys22 + 1) - regions['outer_intersep'] = Region( - name='outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=jys12 + 1, yupper_ind=jys22 + 1) - regions['outer_SOL'] = Region( - name='outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys12 + 1, yupper_ind=jys22 + 1) - regions['lower_outer_PFR'] = Region( - name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys22 + 1, yupper_ind=ny) - regions['lower_outer_intersep'] = Region( - name='lower_outer_intersep', ds=ds, xinner_ind=ixs1, xouter_ind=ixs2, - ylower_ind=jys22 + 1, yupper_ind=ny) - regions['lower_outer_SOL'] = Region( - name='lower_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys22 + 1, yupper_ind=ny) - _create_connection_x(regions, 'lower_inner_PFR', 'lower_inner_intersep') - _create_connection_x(regions, 'lower_inner_intersep', 'lower_inner_SOL') - _create_connection_x(regions, 'inner_core', 'inner_intersep') - _create_connection_x(regions, 'inner_intersep', 'inner_SOL') - _create_connection_x(regions, 'upper_inner_PFR', 'upper_inner_intersep') - _create_connection_x(regions, 'upper_inner_intersep', 'upper_inner_SOL') - _create_connection_x(regions, 'upper_outer_PFR', 'upper_outer_intersep') - _create_connection_x(regions, 'upper_outer_intersep', 'upper_outer_SOL') - _create_connection_x(regions, 'outer_core', 'outer_intersep') - _create_connection_x(regions, 'outer_intersep', 'outer_SOL') - _create_connection_x(regions, 'lower_outer_PFR', 'lower_outer_intersep') - _create_connection_x(regions, 'lower_outer_intersep', 'lower_outer_SOL') - _create_connection_y(regions, 'lower_inner_PFR', 'lower_outer_PFR') - _create_connection_y(regions, 'lower_inner_intersep', 'inner_intersep') - _create_connection_y(regions, 'lower_inner_SOL', 'inner_SOL') - _create_connection_y(regions, 'inner_core', 'outer_core') - _create_connection_y(regions, 'outer_core', 'inner_core') - _create_connection_y(regions, 'inner_intersep', 'outer_intersep') - _create_connection_y(regions, 'inner_SOL', 'upper_inner_SOL') - _create_connection_y(regions, 'upper_outer_intersep', 'upper_inner_intersep') - _create_connection_y(regions, 'upper_outer_PFR', 'upper_inner_PFR') - _create_connection_y(regions, 'upper_outer_SOL', 'outer_SOL') - _create_connection_y(regions, 'outer_intersep', 'lower_outer_intersep') - _create_connection_y(regions, 'outer_SOL', 'lower_outer_SOL') - elif topology == 'connected-double-null': - regions['lower_inner_PFR'] = Region( - name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['lower_inner_SOL'] = Region( - name='lower_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['inner_core'] = Region( - name='inner_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys11 + 1, yupper_ind=jys21 + 1) - regions['inner_SOL'] = Region( - name='inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys11 + 1, yupper_ind=jys21 + 1) - regions['upper_inner_PFR'] = Region( - name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys21 + 1, yupper_ind=nyinner) - regions['upper_inner_SOL'] = Region( - name='upper_inner_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys21 + 1, yupper_ind=nyinner) - regions['upper_outer_PFR'] = Region( - name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=nyinner, yupper_ind=jys12 + 1) - regions['upper_outer_SOL'] = Region( - name='upper_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=nyinner, yupper_ind=jys12 + 1) - regions['outer_core'] = Region( - name='outer_core', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys12 + 1, yupper_ind=jys22 + 1) - regions['outer_SOL'] = Region( - name='outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys12 + 1, yupper_ind=jys22 + 1) - regions['lower_outer_PFR'] = Region( - name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys22 + 1, yupper_ind=ny) - regions['lower_outer_SOL'] = Region( - name='lower_outer_SOL', ds=ds, xinner_ind=ixs2, xouter_ind=nx, - ylower_ind=jys22 + 1, yupper_ind=ny) - _create_connection_x(regions, 'lower_inner_PFR', 'lower_inner_SOL') - _create_connection_x(regions, 'inner_core', 'inner_SOL') - _create_connection_x(regions, 'upper_inner_PFR', 'upper_inner_SOL') - _create_connection_x(regions, 'upper_outer_PFR', 'upper_outer_SOL') - _create_connection_x(regions, 'outer_core', 'outer_SOL') - _create_connection_x(regions, 'lower_outer_PFR', 'lower_outer_SOL') - _create_connection_y(regions, 'lower_inner_PFR', 'lower_outer_PFR') - _create_connection_y(regions, 'lower_inner_SOL', 'inner_SOL') - _create_connection_y(regions, 'inner_core', 'outer_core') - _create_connection_y(regions, 'outer_core', 'inner_core') - _create_connection_y(regions, 'inner_SOL', 'upper_inner_SOL') - _create_connection_y(regions, 'upper_outer_PFR', 'upper_inner_PFR') - _create_connection_y(regions, 'upper_outer_SOL', 'outer_SOL') - _create_connection_y(regions, 'outer_SOL', 'lower_outer_SOL') - elif topology == 'single-null': - regions['inner_PFR'] = Region( - name='inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=0, - yupper_ind=jys11 + 1) - regions['inner_SOL'] = Region( - name='inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=0, - yupper_ind=jys11 + 1) - regions['core'] = Region( - name='core', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=jys11 + 1, - yupper_ind=jys22 + 1) - regions['SOL'] = Region( - name='SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=jys11 + 1, - yupper_ind=jys22 + 1) - regions['outer_PFR'] = Region( - name='lower_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys22 + 1, yupper_ind=ny) - regions['outer_SOL'] = Region( - name='lower_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, - ylower_ind=jys22 + 1, yupper_ind=ny) - _create_connection_x(regions, 'inner_PFR', 'inner_SOL') - _create_connection_x(regions, 'core', 'SOL') - _create_connection_x(regions, 'outer_PFR', 'outer_SOL') - _create_connection_y(regions, 'inner_PFR', 'outer_PFR') - _create_connection_y(regions, 'inner_SOL', 'SOL') - _create_connection_y(regions, 'core', 'core') - _create_connection_y(regions, 'SOL', 'outer_SOL') - elif topology == 'limiter': - regions['core'] = Region( - name='core', ds=ds, xinner_ind=0, xouter_ind=ixs1, ylower_ind=ybndry, - yupper_ind=ny - ybndry) - regions['SOL'] = Region( - name='SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, ylower_ind=0, - yupper_ind=ny) - _create_connection_x(regions, 'core', 'SOL') - _create_connection_y(regions, 'core', 'core') - elif topology == 'core': - regions['core'] = Region( - name='core', ds=ds, xinner_ind=0, xouter_ind=nx, ylower_ind=ybndry, - yupper_ind=ny - ybndry) - _create_connection_y(regions, 'core', 'core') - elif topology == 'sol': - regions['SOL'] = Region( - name='SOL', ds=ds, xinner_ind=0, xouter_ind=nx, ylower_ind=0, - yupper_ind=ny) - elif topology == 'xpoint': - regions['lower_inner_PFR'] = Region( - name='lower_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['lower_inner_SOL'] = Region( - name='lower_inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, - ylower_ind=0, yupper_ind=jys11 + 1) - regions['upper_inner_PFR'] = Region( - name='upper_inner_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys11 + 1, yupper_ind=nyinner) - regions['upper_inner_SOL'] = Region( - name='upper_inner_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, - ylower_ind=jys11 + 1, yupper_ind=nyinner) - regions['upper_outer_PFR'] = Region( - name='upper_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=nyinner, yupper_ind=jys22 + 1) - regions['upper_outer_SOL'] = Region( - name='upper_outer_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, - ylower_ind=nyinner, yupper_ind=jys22 + 1) - regions['lower_outer_PFR'] = Region( - name='lower_outer_PFR', ds=ds, xinner_ind=0, xouter_ind=ixs1, - ylower_ind=jys22 + 1, yupper_ind=ny) - regions['lower_outer_SOL'] = Region( - name='lower_outer_SOL', ds=ds, xinner_ind=ixs1, xouter_ind=nx, - ylower_ind=jys22 + 1, yupper_ind=ny) - _create_connection_x(regions, 'lower_inner_PFR', 'lower_inner_SOL') - _create_connection_x(regions, 'upper_inner_PFR', 'upper_inner_SOL') - _create_connection_x(regions, 'upper_outer_PFR', 'upper_outer_SOL') - _create_connection_x(regions, 'lower_outer_PFR', 'lower_outer_SOL') - _create_connection_y(regions, 'lower_inner_PFR', 'lower_outer_PFR') - _create_connection_y(regions, 'lower_inner_SOL', 'upper_inner_SOL') - _create_connection_y(regions, 'upper_outer_PFR', 'upper_inner_PFR') - _create_connection_y(regions, 'upper_outer_SOL', 'lower_outer_SOL') - else: - raise NotImplementedError("Topology '" + topology + "' is not implemented") + try: + regions = topologies[topology](ds=ds, ixs1=ixs1, ixs2=ixs2, nx=nx, jys11=jys11, + jys21=jys21, ny_inner=ny_inner, jys12=jys12, + jys22=jys22, ny=ny, ybndry=ybndry) + except KeyError: + raise NotImplementedError(f"Topology '{topology}' is not implemented") + + _check_connections(regions) ds = _set_attrs_on_all_vars(ds, 'regions', regions) diff --git a/xbout/tests/test_boutdataarray.py b/xbout/tests/test_boutdataarray.py index 67fd03b0..27cd3b09 100644 --- a/xbout/tests/test_boutdataarray.py +++ b/xbout/tests/test_boutdataarray.py @@ -2,8 +2,11 @@ import dask.array import numpy as np -from numpy.testing import assert_allclose +import numpy.testing as npt +from pathlib import Path +import xarray as xr +import xarray.testing as xrt from xarray.core.utils import dict_equiv from xbout.tests.test_load import bout_xyt_example_files, create_bout_ds @@ -23,7 +26,10 @@ def test_to_dataset(self, tmpdir_factory, bout_xyt_example_files): assert dict_equiv(ds.attrs, new_ds.attrs) assert dict_equiv(ds.metadata, new_ds.metadata) - @pytest.mark.parametrize('nz', [6, 7, 8, 9]) + @pytest.mark.parametrize('nz', [pytest.param(6, marks=pytest.mark.long), + 7, + pytest.param(8, marks=pytest.mark.long), + pytest.param(9, marks=pytest.mark.long)]) def test_toFieldAligned(self, tmpdir_factory, bout_xyt_example_files, nz): path = bout_xyt_example_files(tmpdir_factory, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1) @@ -49,28 +55,28 @@ def test_toFieldAligned(self, tmpdir_factory, bout_xyt_example_files, nz): n_al = n.bout.toFieldAligned() for t in range(ds.sizes['t']): for z in range(nz): - assert_allclose(n_al[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 1, z].values, 1000.*t + 10.*1. + (z + 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 1, z].values, 1000.*t + 10.*1. + (z + 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 2, z].values, 1000.*t + 10.*2. + (z + 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 2, z].values, 1000.*t + 10.*2. + (z + 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 3, z].values, 1000.*t + 10.*3. + (z + 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 3, z].values, 1000.*t + 10.*3. + (z + 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z + 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z + 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z + 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z + 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z + 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z + 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z + 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z + 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 def test_toFieldAligned_dask(self, tmpdir_factory, bout_xyt_example_files): @@ -105,30 +111,33 @@ def test_toFieldAligned_dask(self, tmpdir_factory, bout_xyt_example_files): n_al = n.bout.toFieldAligned() for t in range(ds.sizes['t']): for z in range(nz): - assert_allclose(n_al[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 1, z].values, 1000.*t + 10.*1. + (z + 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 1, z].values, 1000.*t + 10.*1. + (z + 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 2, z].values, 1000.*t + 10.*2. + (z + 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 2, z].values, 1000.*t + 10.*2. + (z + 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 0, 3, z].values, 1000.*t + 10.*3. + (z + 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 0, 3, z].values, 1000.*t + 10.*3. + (z + 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z + 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z + 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z + 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z + 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z + 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z + 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_al[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z + 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_al[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z + 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 - @pytest.mark.parametrize('nz', [6, 7, 8, 9]) + @pytest.mark.parametrize('nz', [pytest.param(6, marks=pytest.mark.long), + 7, + pytest.param(8, marks=pytest.mark.long), + pytest.param(9, marks=pytest.mark.long)]) def test_fromFieldAligned(self, tmpdir_factory, bout_xyt_example_files, nz): path = bout_xyt_example_files(tmpdir_factory, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1) @@ -154,25 +163,305 @@ def test_fromFieldAligned(self, tmpdir_factory, bout_xyt_example_files, nz): n_nal = n.bout.fromFieldAligned() for t in range(ds.sizes['t']): for z in range(nz): - assert_allclose(n_nal[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 + npt.assert_allclose(n_nal[t, 0, 0, z].values, 1000.*t + z % nz, rtol=1.e-15, atol=5.e-16) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 0, 1, z].values, 1000.*t + 10.*1. + (z - 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 0, 1, z].values, 1000.*t + 10.*1. + (z - 1) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 0, 2, z].values, 1000.*t + 10.*2. + (z - 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 0, 2, z].values, 1000.*t + 10.*2. + (z - 2) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 0, 3, z].values, 1000.*t + 10.*3. + (z - 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 0, 3, z].values, 1000.*t + 10.*3. + (z - 3) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z - 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 1, 0, z].values, 1000.*t + 100.*1 + 10.*0. + (z - 4) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z - 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 1, 1, z].values, 1000.*t + 100.*1 + 10.*1. + (z - 5) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z - 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 1, 2, z].values, 1000.*t + 100.*1 + 10.*2. + (z - 6) % nz, rtol=1.e-15, atol=0.) # noqa: E501 for z in range(nz): - assert_allclose(n_nal[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z - 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + npt.assert_allclose(n_nal[t, 1, 3, z].values, 1000.*t + 100.*1 + 10.*3. + (z - 7) % nz, rtol=1.e-15, atol=0.) # noqa: E501 + + @pytest.mark.long + def test_interpolate_parallel_region_core(self, tmpdir_factory, + bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=1, nt=1, grid='grid', guards={'y': 2}, + topology='core') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/16. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 20), + dims='theta') + + dtheta_fine = thetalength/128. + theta_fine = xr.DataArray( + np.linspace(0. + dtheta_fine/2., thetalength - dtheta_fine/2., 128), + dims='theta') + + def f(t): + t = np.sin(t) + return (t**3 - t**2 + t - 1.) + + n.data = f(theta).broadcast_like(n) + + n_highres = n.bout.interpolate_parallel('core') + + expected = f(theta_fine).broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + @pytest.mark.parametrize('res_factor', [pytest.param(2, marks=pytest.mark.long), + 3, + pytest.param(7, marks=pytest.mark.long), + pytest.param(18, marks=pytest.mark.long)]) + def test_interpolate_parallel_region_core_change_n(self, tmpdir_factory, + bout_xyt_example_files, + res_factor): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=1, nt=1, grid='grid', guards={'y': 2}, + topology='core') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/16. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 20), + dims='theta') + + dtheta_fine = thetalength/res_factor/16. + theta_fine = xr.DataArray( + np.linspace(0. + dtheta_fine/2., thetalength - dtheta_fine/2., + res_factor*16), + dims='theta') + + def f(t): + t = np.sin(t) + return (t**3 - t**2 + t - 1.) + + n.data = f(theta).broadcast_like(n) + + n_highres = n.bout.interpolate_parallel('core', n=res_factor) + + expected = f(theta_fine).broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + @pytest.mark.long + def test_interpolate_parallel_region_sol(self, tmpdir_factory, + bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=1, nt=1, grid='grid', guards={'y': 2}, + topology='sol') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/16. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 20), + dims='theta') + + dtheta_fine = thetalength/128. + theta_fine = xr.DataArray( + np.linspace(0. - 1.5*dtheta_fine, thetalength + 1.5*dtheta_fine, 132), + dims='theta') + + def f(t): + t = np.sin(t) + return (t**3 - t**2 + t - 1.) + + n.data = f(theta).broadcast_like(n) + + n_highres = n.bout.interpolate_parallel('SOL') + + expected = f(theta_fine).broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + def test_interpolate_parallel_region_singlenull(self, tmpdir_factory, + bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=3, nt=1, grid='grid', guards={'y': 2}, + topology='single-null') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/48. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 52), + dims='theta') + + dtheta_fine = thetalength/3./128. + theta_fine = xr.DataArray( + np.linspace(0. + 0.5*dtheta_fine, thetalength - 0.5*dtheta_fine, 3*128), + dims='theta') + + def f(t): + t = np.sin(3.*t) + return (t**3 - t**2 + t - 1.) + + n.data = f(theta).broadcast_like(n) + + f_fine = f(theta_fine)[:128] + + for region in ['inner_PFR', 'inner_SOL']: + n_highres = n.bout.interpolate_parallel(region).isel(theta=slice(2, None)) + + expected = f_fine.broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + for region in ['core', 'SOL']: + n_highres = n.bout.interpolate_parallel(region) + + expected = f_fine.broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + for region in ['outer_PFR', 'outer_SOL']: + n_highres = n.bout.interpolate_parallel(region).isel(theta=slice(-2)) + + expected = f_fine.broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, rtol=0., atol=1.e-2) + + def test_interpolate_parallel(self, tmpdir_factory, bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=3, nt=1, grid='grid', guards={'y': 2}, + topology='single-null') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/48. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 52), + dims='theta') + + dtheta_fine = thetalength/3./128. + theta_fine = xr.DataArray( + np.linspace(0. + 0.5*dtheta_fine, thetalength - 0.5*dtheta_fine, 3*128), + dims='theta') + x = xr.DataArray(np.arange(3), dims='x') + + def f_y(t): + t = np.sin(3.*t) + return (t**3 - t**2 + t - 1.) + + f = f_y(theta) * (x + 1.) + + n.data = f.broadcast_like(n) + + f_fine = f_y(theta_fine)*(x + 1.) + + n_highres = n.bout.interpolate_parallel().isel(theta=slice(2, -2)) + + expected = f_fine.broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, + rtol=0., atol=1.1e-2) + + def test_interpolate_parallel_sol(self, tmpdir_factory, bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=1, nt=1, grid='grid', guards={'y': 2}, + topology='sol') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n = ds['n'] + + thetalength = 2.*np.pi + + dtheta = thetalength/16. + theta = xr.DataArray(np.linspace(0. - 1.5*dtheta, thetalength + 1.5*dtheta, 20), + dims='theta') + + dtheta_fine = thetalength/128. + theta_fine = xr.DataArray( + np.linspace(0. + 0.5*dtheta_fine, thetalength - 0.5*dtheta_fine, 128), + dims='theta') + x = xr.DataArray(np.arange(3), dims='x') + + def f_y(t): + t = np.sin(t) + return (t**3 - t**2 + t - 1.) + + f = f_y(theta) * (x + 1.) + + n.data = f.broadcast_like(n) + + f_fine = f_y(theta_fine)*(x + 1.) + + n_highres = n.bout.interpolate_parallel().isel(theta=slice(2, -2)) + + expected = f_fine.broadcast_like(n_highres) + + npt.assert_allclose(n_highres.values, expected.values, + rtol=0., atol=1.1e-2) + + def test_interpolate_parallel_toroidal_points(self, tmpdir_factory, + bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=3, nt=1, grid='grid', guards={'y': 2}, + topology='single-null') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n_highres = ds['n'].bout.interpolate_parallel() + + n_highres_truncated = ds['n'].bout.interpolate_parallel(toroidal_points=2) + + xrt.assert_identical(n_highres_truncated, n_highres.isel(zeta=[0, 2])) + + def test_interpolate_parallel_toroidal_points_list(self, tmpdir_factory, + bout_xyt_example_files): + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 16, 3), nxpe=1, + nype=3, nt=1, grid='grid', guards={'y': 2}, + topology='single-null') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_yboundaries=True) + + n_highres = ds['n'].bout.interpolate_parallel() + + points_list = [1, 2] + + n_highres_truncated = ds['n'].bout.interpolate_parallel( + toroidal_points=points_list) + + xrt.assert_identical(n_highres_truncated, n_highres.isel(zeta=points_list)) diff --git a/xbout/tests/test_boutdataset.py b/xbout/tests/test_boutdataset.py index 64b13270..251bb117 100644 --- a/xbout/tests/test_boutdataset.py +++ b/xbout/tests/test_boutdataset.py @@ -1,13 +1,17 @@ import pytest +import numpy.testing as npt from xarray import Dataset, DataArray, concat, open_dataset, open_mfdataset import xarray.testing as xrt import numpy as np from pathlib import Path from xbout.tests.test_load import bout_xyt_example_files, create_bout_ds +from xbout.tests.test_region import (params_guards, params_guards_values, + params_boundaries, params_boundaries_values) from xbout import BoutDatasetAccessor, open_boutdataset, reload_boutdataset from xbout.geometries import apply_geometry +from xbout.utils import _set_attrs_on_all_vars EXAMPLE_OPTIONS_FILE_PATH = './xbout/tests/data/options/BOUT.inp' @@ -81,6 +85,378 @@ def test_getFieldAligned(self, tmpdir_factory, bout_xyt_example_files): ds['n_aligned'] = ds['T'] xrt.assert_allclose(ds.bout.getFieldAligned('n'), ds['T']) + def test_set_parallel_interpolation_factor(self): + ds = Dataset() + ds['a'] = DataArray() + ds = _set_attrs_on_all_vars(ds, 'metadata', {}) + + with pytest.raises(KeyError): + ds.metadata['fine_interpolation_factor'] + with pytest.raises(KeyError): + ds['a'].metadata['fine_interpolation_factor'] + + ds.bout.fine_interpolation_factor = 42 + + assert ds.metadata['fine_interpolation_factor'] == 42 + assert ds['a'].metadata['fine_interpolation_factor'] == 42 + + @pytest.mark.parametrize(params_guards, params_guards_values) + @pytest.mark.parametrize(params_boundaries, params_boundaries_values) + @pytest.mark.parametrize( + "vars_to_interpolate", [('n', 'T'), pytest.param(..., marks=pytest.mark.long)] + ) + def test_interpolate_parallel(self, tmpdir_factory, bout_xyt_example_files, + guards, keep_xboundaries, keep_yboundaries, + vars_to_interpolate): + # This test checks that the regions created in the new high-resolution Dataset by + # interpolate_parallel are correct. + # This test does not test the accuracy of the parallel interpolation (there are + # other tests for that). + + # Note using more than MXG x-direction points and MYG y-direction points per + # output file ensures tests for whether boundary cells are present do not fail + # when using minimal numbers of processors + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 4, 3), nxpe=3, + nype=6, nt=1, guards=guards, grid='grid', + topology='disconnected-double-null') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal', keep_xboundaries=keep_xboundaries, + keep_yboundaries=keep_yboundaries) + + # Get high parallel resolution version of ds, and check that + ds = ds.bout.interpolate_parallel(vars_to_interpolate) + + mxg = guards['x'] + myg = guards['y'] + + if keep_xboundaries: + ixs1 = ds.metadata['ixseps1'] + else: + ixs1 = ds.metadata['ixseps1'] - guards['x'] + + if keep_xboundaries: + ixs2 = ds.metadata['ixseps2'] + else: + ixs2 = ds.metadata['ixseps2'] - guards['x'] + + if keep_yboundaries: + ybndry = guards['y'] + else: + ybndry = 0 + jys11 = ds.metadata['jyseps1_1'] + ybndry + jys21 = ds.metadata['jyseps2_1'] + ybndry + ny_inner = ds.metadata['ny_inner'] + 2*ybndry + jys12 = ds.metadata['jyseps1_2'] + 3*ybndry + jys22 = ds.metadata['jyseps2_2'] + 3*ybndry + ny = ds.metadata['ny'] + 4*ybndry + + for var in ['n', 'T']: + v = ds[var] + + v_lower_inner_PFR = v.bout.from_region('lower_inner_PFR') + + # Remove attributes that are expected to be different + del v_lower_inner_PFR.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), theta=slice(jys11 + 1)), + v_lower_inner_PFR.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys22 + 1, jys22 + 1 + myg)).values, + v_lower_inner_PFR.isel(theta=slice(-myg, None)).values) + + v_lower_inner_intersep = v.bout.from_region('lower_inner_intersep') + + # Remove attributes that are expected to be different + del v_lower_inner_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys11 + 1)), + v_lower_inner_intersep.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys11 + 1, jys11 + 1 + myg)).values, + v_lower_inner_intersep.isel( + theta=slice(-myg, None)).values) + + v_lower_inner_SOL = v.bout.from_region('lower_inner_SOL') + + # Remove attributes that are expected to be different + del v_lower_inner_SOL.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys11 + 1)), + v_lower_inner_SOL.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys11 + 1, jys11 + 1 + myg)).values, + v_lower_inner_SOL.isel(theta=slice(-myg, None)).values) + + v_inner_core = v.bout.from_region('inner_core') + + # Remove attributes that are expected to be different + del v_inner_core.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys11 + 1, jys21 + 1)), + v_inner_core.isel( + theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys22 + 1 - myg, jys22 + 1)).values, + v_inner_core.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys12 + 1, jys12 + 1 + myg)).values, + v_inner_core.isel(theta=slice(-myg, None)).values) + + v_inner_intersep = v.bout.from_region('inner_intersep') + + # Remove attributes that are expected to be different + del v_inner_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys11 + 1, jys21 + 1)), + v_inner_intersep.isel( + theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys11 + 1 - myg, jys11 + 1)).values, + v_inner_intersep.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys12 + 1, jys12 + 1 + myg)).values, + v_inner_intersep.isel(theta=slice(-myg, None)).values) + + v_inner_sol = v.bout.from_region('inner_SOL') + + # Remove attributes that are expected to be different + del v_inner_sol.attrs['region'] + xrt.assert_identical( + v.isel(x=slice(ixs2 - mxg, None), theta=slice(jys11 + 1, jys21 + 1)), + v_inner_sol.isel(theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys11 + 1 - myg, jys11 + 1)).values, + v_inner_sol.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys21 + 1, jys21 + 1 + myg)).values, + v_inner_sol.isel(theta=slice(-myg, None)).values) + + v_upper_inner_PFR = v.bout.from_region('upper_inner_PFR') + + # Remove attributes that are expected to be different + del v_upper_inner_PFR.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys21 + 1, ny_inner)), + v_upper_inner_PFR.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys12 + 1 - myg, jys12 + 1)).values, + v_upper_inner_PFR.isel(theta=slice(myg)).values) + + v_upper_inner_intersep = v.bout.from_region('upper_inner_intersep') + + # Remove attributes that are expected to be different + del v_upper_inner_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys21 + 1, ny_inner)), + v_upper_inner_intersep.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys12 + 1 - myg, jys12 + 1)).values, + v_upper_inner_intersep.isel(theta=slice(myg)).values) + + v_upper_inner_SOL = v.bout.from_region('upper_inner_SOL') + + # Remove attributes that are expected to be different + del v_upper_inner_SOL.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys21 + 1, ny_inner)), + v_upper_inner_SOL.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys21 + 1 - myg, jys21 + 1)).values, + v_upper_inner_SOL.isel(theta=slice(myg)).values) + + v_upper_outer_PFR = v.bout.from_region('upper_outer_PFR') + + # Remove attributes that are expected to be different + del v_upper_outer_PFR.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), + theta=slice(ny_inner, jys12 + 1)), + v_upper_outer_PFR.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys21 + 1, jys21 + 1 + myg)).values, + v_upper_outer_PFR.isel(theta=slice(-myg, None)).values) + + v_upper_outer_intersep = v.bout.from_region('upper_outer_intersep') + + # Remove attributes that are expected to be different + del v_upper_outer_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(ny_inner, jys12 + 1)), + v_upper_outer_intersep.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys21 + 1, jys21 + 1 + myg)).values, + v_upper_outer_intersep.isel( + theta=slice(-myg, None)).values) + + v_upper_outer_SOL = v.bout.from_region('upper_outer_SOL') + + # Remove attributes that are expected to be different + del v_upper_outer_SOL.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(ny_inner, jys12 + 1)), + v_upper_outer_SOL.isel( + theta=slice(-myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys12 + 1, jys12 + 1 + myg)).values, + v_upper_outer_SOL.isel(theta=slice(-myg, None)).values) + + v_outer_core = v.bout.from_region('outer_core') + + # Remove attributes that are expected to be different + del v_outer_core.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys12 + 1, jys22 + 1)), + v_outer_core.isel( + theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys21 + 1 - myg, jys21 + 1)).values, + v_outer_core.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys11 + 1, jys11 + 1 + myg)).values, + v_outer_core.isel(theta=slice(-myg, None)).values) + + v_outer_intersep = v.bout.from_region('outer_intersep') + + # Remove attributes that are expected to be different + del v_outer_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys12 + 1, jys22 + 1)), + v_outer_intersep.isel( + theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys21 + 1 - myg, jys21 + 1)).values, + v_outer_intersep.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys22 + 1, jys22 + 1 + myg)).values, + v_outer_intersep.isel(theta=slice(-myg, None)).values) + + v_outer_sol = v.bout.from_region('outer_SOL') + + # Remove attributes that are expected to be different + del v_outer_sol.attrs['region'] + xrt.assert_identical( + v.isel(x=slice(ixs2 - mxg, None), theta=slice(jys12 + 1, jys22 + 1)), + v_outer_sol.isel(theta=slice(myg, -myg if myg != 0 else None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys12 + 1 - myg, jys12 + 1)).values, + v_outer_sol.isel(theta=slice(myg)).values) + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys22 + 1, jys22 + 1 + myg)).values, + v_outer_sol.isel(theta=slice(-myg, None)).values) + + v_lower_outer_PFR = v.bout.from_region('lower_outer_PFR') + + # Remove attributes that are expected to be different + del v_lower_outer_PFR.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys22 + 1, None)), + v_lower_outer_PFR.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 + mxg), + theta=slice(jys11 + 1 - myg, jys11 + 1)).values, + v_lower_outer_PFR.isel(theta=slice(myg)).values) + + v_lower_outer_intersep = v.bout.from_region('lower_outer_intersep') + + # Remove attributes that are expected to be different + del v_lower_outer_intersep.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys22 + 1, None)), + v_lower_outer_intersep.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs1 - mxg, ixs2 + mxg), + theta=slice(jys22 + 1 - myg, jys22 + 1)).values, + v_lower_outer_intersep.isel(theta=slice(myg)).values) + + v_lower_outer_SOL = v.bout.from_region('lower_outer_SOL') + + # Remove attributes that are expected to be different + del v_lower_outer_SOL.attrs['region'] + xrt.assert_identical(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys22 + 1, None)), + v_lower_outer_SOL.isel(theta=slice(myg, None))) + if myg > 0: + # check y-guards, which were 'communicated' by from_region + # Coordinates are not equal, so only compare array values + npt.assert_equal(v.isel(x=slice(ixs2 - mxg, None), + theta=slice(jys22 + 1 - myg, jys22 + 1)).values, + v_lower_outer_SOL.isel(theta=slice(myg)).values) + + def test_interpolate_parallel_all_variables_arg(self, tmpdir_factory, + bout_xyt_example_files): + # Check that passing 'variables=...' to interpolate_parallel() does actually + # interpolate all the variables + path = bout_xyt_example_files(tmpdir_factory, lengths=(2, 3, 4, 3), nxpe=1, + nype=1, nt=1, grid='grid', topology='sol') + + ds = open_boutdataset(datapath=path, + gridfilepath=Path(path).parent.joinpath('grid.nc'), + geometry='toroidal') + + # Get high parallel resolution version of ds, and check that + ds = ds.bout.interpolate_parallel(...) + + interpolated_variables = [v for v in ds] + + assert set(interpolated_variables) == set(( + 'n', 'T', 'g11', 'g22', 'g33', 'g12', 'g13', 'g23', 'g_11', 'g_22', 'g_33', + 'g_12', 'g_13', 'g_23', 'G1', 'G2', 'G3', 'J', 'Bxy', 'dx', 'dy' + )) + class TestLoadInputFile: @pytest.mark.skip @@ -149,16 +525,6 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry): # Load it again recovered = reload_boutdataset(savepath) - # Compare - for coord in original.coords.values(): - # Get rid of the options if they exist, because options are not dealt with - # totally consistently: they exist if a coord was created from a variable - # loaded from the BOUT++ output, but not if the coord was calculated from - # some parameters or loaded from a grid file - try: - del coord.attrs["options"] - except KeyError: - pass xrt.assert_identical(original.load(), recovered.load()) @pytest.mark.skip("saving and loading as float32 does not work") @@ -231,15 +597,6 @@ def test_reload_separate_variables( recovered = reload_boutdataset(savepath, pre_squashed=True) # Compare - for coord in original.coords.values(): - # Get rid of the options if they exist, because options are not dealt with - # totally consistently: they exist if a coord was created from a variable - # loaded from the BOUT++ output, but not if the coord was calculated from - # some parameters or loaded from a grid file - try: - del coord.attrs["options"] - except KeyError: - pass xrt.assert_identical(recovered, original) diff --git a/xbout/tests/test_geometries.py b/xbout/tests/test_geometries.py index 61ebc5b6..1ccb9f1e 100644 --- a/xbout/tests/test_geometries.py +++ b/xbout/tests/test_geometries.py @@ -1,3 +1,5 @@ +import numpy as np + from xarray import Dataset, DataArray from xarray.testing import assert_equal import pytest @@ -21,6 +23,8 @@ def add_schwarzschild_coords(ds, coordinates=None): assert "Schwarzschild" in REGISTERED_GEOMETRIES.keys() original = Dataset() + original['dy'] = DataArray(np.ones((3, 4)), dims=('x', 'y')) + original.attrs['metadata'] = {} updated = apply_geometry(ds=original, geometry_name="Schwarzschild") assert_equal(updated['event_horizon'], DataArray(4.0)) diff --git a/xbout/tests/test_grid.py b/xbout/tests/test_grid.py index a1e079f3..d30361b4 100644 --- a/xbout/tests/test_grid.py +++ b/xbout/tests/test_grid.py @@ -20,7 +20,8 @@ def create_example_grid_file(tmpdir_factory): # Create grid dataset arr = np.arange(6).reshape(2, 3) - grid = DataArray(data=arr, dims=['x', 'y']) + grid = DataArray(data=arr, name='arr', dims=['x', 'y']).to_dataset() + grid['dy'] = DataArray(np.ones((2, 3)), dims=['x', 'y']) # Create temporary directory save_dir = tmpdir_factory.mktemp("griddata") @@ -59,6 +60,7 @@ def test_open_grid_apply_geometry(self, create_example_grid_file): @register_geometry(name="Schwarzschild") def add_schwarzschild_coords(ds, coordinates=None): ds['event_horizon'] = 4.0 + ds['event_horizon'].attrs = ds.attrs.copy() return ds example_grid = create_example_grid_file diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index 97a40bbf..f9bd59ff 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -320,6 +320,9 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, T = DataArray(data, dims=['t', 'x', 'y', 'z']) n = DataArray(data, dims=['t', 'x', 'y', 'z']) + for v in [n, T]: + v.attrs['direction_y'] = 'Standard' + v.attrs['cell_location'] = 'CELL_CENTRE' ds = Dataset({'n': n, 'T': T}) # BOUT_VERSION needed so that we know that number of points in z is MZ, not MZ-1 (as @@ -372,8 +375,7 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, ds['ny_inner'] = ny//2 elif topology == 'xpoint': if nype < 4: - raise ValueError('Not enough processors for xpoint topology: ' - + 'nype=' + str(nype)) + raise ValueError(f'Not enough processors for xpoint topology: nype={nype}') ds['ixseps1'] = nx//2 ds['ixseps2'] = nx//2 ds['jyseps1_1'] = MYSUB - 1 @@ -384,8 +386,7 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, ds['jyseps2_2'] = ny - MYSUB - 1 elif topology == 'single-null': if nype < 3: - raise ValueError('Not enough processors for single-null topology: ' - + 'nype=' + str(nype)) + raise ValueError(f'Not enough processors for xpoint topology: nype={nype}') ds['ixseps1'] = nx//2 ds['ixseps2'] = nx ds['jyseps1_1'] = MYSUB - 1 @@ -396,7 +397,7 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, elif topology == 'connected-double-null': if nype < 6: raise ValueError('Not enough processors for connected-double-null topology: ' - + 'nype=' + str(nype)) + f'nype={nype}') ds['ixseps1'] = nx//2 ds['ixseps2'] = nx//2 ds['jyseps1_1'] = MYSUB - 1 @@ -408,12 +409,12 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, elif topology == 'disconnected-double-null': if nype < 6: raise ValueError('Not enough processors for disconnected-double-null ' - + 'topology: nype=' + str(nype)) + f'topology: nype={nype}') ds['ixseps1'] = nx//2 ds['ixseps2'] = nx//2 + 4 if ds['ixseps2'] >= nx: raise ValueError('Not enough points in the x-direction. ixseps2=' - + str(ds['ixseps2']) + ' > nx=' + str(nx)) + f'{ds["ixseps2"]} > nx={nx}') ds['jyseps1_1'] = MYSUB - 1 ny_inner = 3*MYSUB ds['ny_inner'] = ny_inner @@ -421,7 +422,7 @@ def create_bout_ds(syn_data_type='random', lengths=(6, 2, 4, 7), num=0, nxpe=1, ds['jyseps1_2'] = ny_inner + MYSUB - 1 ds['jyseps2_2'] = ny - MYSUB - 1 else: - raise ValueError('Unrecognised topology=' + str(topology)) + raise ValueError(f'Unrecognised topology={topology}') one = DataArray(np.ones((x_length, y_length)), dims=['x', 'y']) zero = DataArray(np.zeros((x_length, y_length)), dims=['x', 'y']) diff --git a/xbout/tests/test_region.py b/xbout/tests/test_region.py index 201aa219..de162127 100644 --- a/xbout/tests/test_region.py +++ b/xbout/tests/test_region.py @@ -9,15 +9,21 @@ from xbout import open_boutdataset -class TestRegion: +params_guards = "guards" +params_guards_values = [pytest.param({'x': 0, 'y': 0}, marks=pytest.mark.long), + pytest.param({'x': 2, 'y': 0}, marks=pytest.mark.long), + pytest.param({'x': 0, 'y': 2}, marks=pytest.mark.long), + {'x': 2, 'y': 2}] +params_boundaries = "keep_xboundaries, keep_yboundaries" +params_boundaries_values = [pytest.param(False, False, marks=pytest.mark.long), + pytest.param(True, False, marks=pytest.mark.long), + pytest.param(False, True, marks=pytest.mark.long), + (True, True)] + - params_guards = "guards" - params_guards_values = [{'x': 0, 'y': 0}, {'x': 2, 'y': 0}, {'x': 0, 'y': 2}, - {'x': 2, 'y': 2}] - params_boundaries = "keep_xboundaries, keep_yboundaries" - params_boundaries_values = [(False, False), (True, False), (False, True), - (True, True)] +class TestRegion: + @pytest.mark.long @pytest.mark.parametrize(params_guards, params_guards_values) @pytest.mark.parametrize(params_boundaries, params_boundaries_values) def test_region_core(self, tmpdir_factory, bout_xyt_example_files, guards, @@ -53,6 +59,7 @@ def test_region_core(self, tmpdir_factory, bout_xyt_example_files, guards, n.isel(theta=slice(ybndry, -ybndry if ybndry != 0 else None)), n_core.isel(theta=slice(ybndry, -ybndry if ybndry != 0 else None))) + @pytest.mark.long @pytest.mark.parametrize(params_guards, params_guards_values) @pytest.mark.parametrize(params_boundaries, params_boundaries_values) def test_region_sol(self, tmpdir_factory, bout_xyt_example_files, guards, @@ -133,6 +140,7 @@ def test_region_limiter(self, tmpdir_factory, bout_xyt_example_files, guards, theta=slice(ybndry, -ybndry if ybndry != 0 else None)), n_core.isel(theta=slice(ybndry, -ybndry if ybndry != 0 else None))) + @pytest.mark.long @pytest.mark.parametrize(params_guards, params_guards_values) @pytest.mark.parametrize(params_boundaries, params_boundaries_values) def test_region_xpoint(self, tmpdir_factory, bout_xyt_example_files, guards, @@ -282,6 +290,7 @@ def test_region_xpoint(self, tmpdir_factory, bout_xyt_example_files, guards, theta=slice(jys2 + 1 - myg, jys2 + 1)).values, n_lower_outer_SOL.isel(theta=slice(myg)).values) + @pytest.mark.long @pytest.mark.parametrize(params_guards, params_guards_values) @pytest.mark.parametrize(params_boundaries, params_boundaries_values) def test_region_singlenull(self, tmpdir_factory, bout_xyt_example_files, guards, @@ -403,6 +412,7 @@ def test_region_singlenull(self, tmpdir_factory, bout_xyt_example_files, guards, theta=slice(jys2 + 1 - myg, jys2 + 1)).values, n_outer_SOL.isel(theta=slice(myg)).values) + @pytest.mark.long @pytest.mark.parametrize(params_guards, params_guards_values) @pytest.mark.parametrize(params_boundaries, params_boundaries_values) def test_region_connecteddoublenull(self, tmpdir_factory, bout_xyt_example_files, @@ -996,7 +1006,8 @@ def test_region_disconnecteddoublenull_get_one_guard( n = ds['n'] - n_lower_inner_PFR = n.bout.from_region('lower_inner_PFR', with_guards=with_guards) + n_lower_inner_PFR = n.bout.from_region('lower_inner_PFR', + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_inner_PFR.attrs['region'] @@ -1011,7 +1022,7 @@ def test_region_disconnecteddoublenull_get_one_guard( n_lower_inner_PFR.isel(theta=slice(-yguards, None)).values) n_lower_inner_intersep = n.bout.from_region('lower_inner_intersep', - with_guards=with_guards) + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_inner_intersep.attrs['region'] @@ -1027,7 +1038,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys11 + 1, jys11 + 1 + yguards)).values, n_lower_inner_intersep.isel(theta=slice(-yguards, None)).values) - n_lower_inner_SOL = n.bout.from_region('lower_inner_SOL', with_guards=with_guards) + n_lower_inner_SOL = n.bout.from_region('lower_inner_SOL', + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_inner_SOL.attrs['region'] @@ -1096,7 +1108,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys21 + 1, jys21 + 1 + yguards)).values, n_inner_sol.isel(theta=slice(-yguards, None)).values) - n_upper_inner_PFR = n.bout.from_region('upper_inner_PFR', with_guards=with_guards) + n_upper_inner_PFR = n.bout.from_region('upper_inner_PFR', + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_inner_PFR.attrs['region'] @@ -1111,7 +1124,7 @@ def test_region_disconnecteddoublenull_get_one_guard( n_upper_inner_PFR.isel(theta=slice(yguards)).values) n_upper_inner_intersep = n.bout.from_region('upper_inner_intersep', - with_guards=with_guards) + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_inner_intersep.attrs['region'] @@ -1125,7 +1138,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys12 + 1 - yguards, jys12 + 1)).values, n_upper_inner_intersep.isel(theta=slice(yguards)).values) - n_upper_inner_SOL = n.bout.from_region('upper_inner_SOL', with_guards=with_guards) + n_upper_inner_SOL = n.bout.from_region('upper_inner_SOL', + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_inner_SOL.attrs['region'] @@ -1139,7 +1153,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys21 + 1 - yguards, jys21 + 1)).values, n_upper_inner_SOL.isel(theta=slice(yguards)).values) - n_upper_outer_PFR = n.bout.from_region('upper_outer_PFR', with_guards=with_guards) + n_upper_outer_PFR = n.bout.from_region('upper_outer_PFR', + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_outer_PFR.attrs['region'] @@ -1155,7 +1170,7 @@ def test_region_disconnecteddoublenull_get_one_guard( n_upper_outer_PFR.isel(theta=slice(-yguards, None)).values) n_upper_outer_intersep = n.bout.from_region('upper_outer_intersep', - with_guards=with_guards) + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_outer_intersep.attrs['region'] @@ -1171,7 +1186,8 @@ def test_region_disconnecteddoublenull_get_one_guard( n_upper_outer_intersep.isel( theta=slice(-yguards, None)).values) - n_upper_outer_SOL = n.bout.from_region('upper_outer_SOL', with_guards=with_guards) + n_upper_outer_SOL = n.bout.from_region('upper_outer_SOL', + with_guards=with_guards) # Remove attributes that are expected to be different del n_upper_outer_SOL.attrs['region'] @@ -1241,7 +1257,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys22 + 1, jys22 + 1 + yguards)).values, n_outer_sol.isel(theta=slice(-yguards, None)).values) - n_lower_outer_PFR = n.bout.from_region('lower_outer_PFR', with_guards=with_guards) + n_lower_outer_PFR = n.bout.from_region('lower_outer_PFR', + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_outer_PFR.attrs['region'] @@ -1256,7 +1273,7 @@ def test_region_disconnecteddoublenull_get_one_guard( n_lower_outer_PFR.isel(theta=slice(yguards)).values) n_lower_outer_intersep = n.bout.from_region('lower_outer_intersep', - with_guards=with_guards) + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_outer_intersep.attrs['region'] @@ -1270,7 +1287,8 @@ def test_region_disconnecteddoublenull_get_one_guard( theta=slice(jys22 + 1 - yguards, jys22 + 1)).values, n_lower_outer_intersep.isel(theta=slice(yguards)).values) - n_lower_outer_SOL = n.bout.from_region('lower_outer_SOL', with_guards=with_guards) + n_lower_outer_SOL = n.bout.from_region('lower_outer_SOL', + with_guards=with_guards) # Remove attributes that are expected to be different del n_lower_outer_SOL.attrs['region'] @@ -1280,6 +1298,7 @@ def test_region_disconnecteddoublenull_get_one_guard( if yguards > 0: # check y-guards, which were 'communicated' by from_region # Coordinates are not equal, so only compare array values - npt.assert_equal(n.isel(x=slice(ixs2 - xguards, None), - theta=slice(jys22 + 1 - yguards, jys22 + 1)).values, - n_lower_outer_SOL.isel(theta=slice(yguards)).values) + npt.assert_equal( + n.isel(x=slice(ixs2 - xguards, None), + theta=slice(jys22 + 1 - yguards, jys22 + 1)).values, + n_lower_outer_SOL.isel(theta=slice(yguards)).values) diff --git a/xbout/tests/test_utils.py b/xbout/tests/test_utils.py index c973368e..546e5c9d 100644 --- a/xbout/tests/test_utils.py +++ b/xbout/tests/test_utils.py @@ -2,7 +2,7 @@ from xarray import Dataset, DataArray -from xbout.utils import _set_attrs_on_all_vars +from xbout.utils import _set_attrs_on_all_vars, _update_metadata_increased_resolution class TestUtils: @@ -36,3 +36,26 @@ def test__set_attrs_on_all_vars_copy(self): assert ds.metadata['x'] == 5 assert ds['a'].metadata['x'] == 3 assert ds['b'].metadata['x'] == 3 + + def test__update_metadata_increased_resolution(self): + da = DataArray() + da.attrs['metadata'] = { + 'jyseps1_1': 1, + 'jyseps2_1': 2, + 'ny_inner': 3, + 'jyseps1_2': 4, + 'jyseps2_2': 5, + 'ny': 6, + 'MYSUB': 7, + } + + da = _update_metadata_increased_resolution(da, 3) + + assert da.metadata['jyseps1_1'] == 5 + assert da.metadata['jyseps2_1'] == 8 + assert da.metadata['jyseps1_2'] == 14 + assert da.metadata['jyseps2_2'] == 17 + + assert da.metadata['ny_inner'] == 9 + assert da.metadata['ny'] == 18 + assert da.metadata['MYSUB'] == 21 diff --git a/xbout/utils.py b/xbout/utils.py index 5237a902..d190b460 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -1,4 +1,5 @@ from copy import deepcopy +from itertools import chain import numpy as np @@ -6,14 +7,25 @@ def _set_attrs_on_all_vars(ds, key, attr_data, copy=False): ds.attrs[key] = attr_data if copy: - for da in ds.values(): - da.attrs[key] = deepcopy(attr_data) + for v in chain(ds.data_vars, ds.coords): + ds[v].attrs[key] = deepcopy(attr_data) else: - for da in ds.values(): - da.attrs[key] = attr_data + for v in chain(ds.data_vars, ds.coords): + ds[v].attrs[key] = attr_data return ds +def _add_attrs_to_var(ds, varname, copy=False): + if copy: + for attr in ["metadata", "options", "geometry", "regions"]: + if attr in ds.attrs and attr not in ds[varname].attrs: + ds[varname].attrs[attr] = deepcopy(ds.attrs[attr]) + else: + for attr in ["metadata", "options", "geometry", "regions"]: + if attr in ds.attrs and attr not in ds[varname].attrs: + ds[varname].attrs[attr] = ds.attrs[attr] + + def _check_filetype(path): if path.suffix == '.nc': filetype = 'netcdf4' @@ -42,3 +54,42 @@ def _separate_metadata(ds): metadata = dict(zip(scalar_vars, metadata_vals)) return ds.drop(scalar_vars), metadata + + +def _update_metadata_increased_resolution(da, n): + """ + Update the metadata variables to account for a y-direction resolution increased by a + factor n. + + Parameters + ---------- + da : DataArray + The variable to update + n : int + The factor to increase the y-resolution by + """ + + # Take deepcopy to ensure we do not alter metadata of other variables + da.attrs['metadata'] = deepcopy(da.metadata) + + def update_jyseps(name): + # If any jyseps<=0, need to leave as is + if da.metadata[name] > 0: + da.metadata[name] = n*(da.metadata[name] + 1) - 1 + update_jyseps('jyseps1_1') + update_jyseps('jyseps2_1') + update_jyseps('jyseps1_2') + update_jyseps('jyseps2_2') + + def update_ny(name): + da.metadata[name] = n*da.metadata[name] + update_ny('ny') + update_ny('ny_inner') + update_ny('MYSUB') + + # Update attrs of coordinates to be consistent with da + for coord in da.coords: + da[coord].attrs = {} + _add_attrs_to_var(da, coord) + + return da