Skip to content

Commit

Permalink
Merge pull request #2259 from ghiggi/refactor-cfwriter
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud authored Jun 27, 2023
2 parents fcea851 + e0d622c commit 8fb627b
Show file tree
Hide file tree
Showing 8 changed files with 1,165 additions and 497 deletions.
124 changes: 124 additions & 0 deletions satpy/_scene_converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) 2023 Satpy developers
#
# This file is part of satpy.
#
# satpy is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# satpy is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""Helper functions for converting the Scene object to some other object."""

import xarray as xr

from satpy.dataset import DataID


def _get_dataarrays_from_identifiers(scn, identifiers):
    """Collect the DataArrays of a Scene matching the given identifiers.

    ``identifiers`` may be a single string or ``DataID`` (it is then wrapped
    into a one-element list), an iterable of keys used to index ``scn``, or
    ``None``. With ``None`` every entry of the Scene wishlist is looked up
    and entries without an available dataset are left out.
    """
    if identifiers is None:
        # No explicit selection: walk the wishlist and skip missing datasets.
        candidates = (scn._datasets.get(wished) for wished in scn._wishlist)
        return [found for found in candidates if found is not None]

    if isinstance(identifiers, (str, DataID)):
        identifiers = [identifiers]
    return [scn[key] for key in identifiers]


def to_xarray(scn,
              datasets=None,
              header_attrs=None,
              exclude_attrs=None,
              flatten_attrs=False,
              pretty=True,
              include_lonlats=True,
              epoch=None,
              include_orig_name=True,
              numeric_name_prefix='CHANNEL_'):
    """Convert the DataArrays of a satpy.Scene into a CF-compliant xr.Dataset.

    The selected DataArrays are handed to ``collect_cf_datasets`` without
    grouping. When exactly one dataset group comes back it is returned;
    otherwise the conversion is not supported yet (a DataTree grouped by
    area might be returned in the future).

    Parameters
    ----------
    scn : satpy.Scene
        Satpy Scene.
    datasets : iterable, optional
        Satpy Scene datasets to include in the output xr.Dataset.
        Elements can be string name, a wavelength as a number, a DataID,
        or DataQuery object.
        If None (the default), all loaded Scene datasets are included.
    header_attrs : optional
        Global attributes of the output xr.Dataset.
    exclude_attrs : list, optional
        xr.DataArray attribute names to be excluded.
    flatten_attrs : bool
        If True, flatten dict-type attributes.
    pretty : bool
        Don't modify coordinate names, if possible. Makes the file prettier,
        but possibly less consistent.
    include_lonlats : bool
        If True, include 'latitude' and 'longitude' coordinates.
        If the 'area' attribute is a SwathDefinition, latitude and
        longitude coordinates are always included.
    epoch : str, optional
        Reference time for encoding the time coordinates (if available).
        Example format: "seconds since 1970-01-01 00:00:00".
        If None, ``EPOCH`` from ``satpy.writers.cf_writer`` is used.
    include_orig_name : bool
        Include the original dataset name as a variable attribute in the
        xr.Dataset.
    numeric_name_prefix : str
        Prefix to add to each variable whose name starts with a digit.
        Use '' or None to leave this out.

    Returns
    -------
    xr.Dataset
        A CF-compliant xr.Dataset. It is empty when no DataArray could be
        retrieved from the Scene.

    Raises
    ------
    NotImplementedError
        If the collected datasets do not form exactly one group.
    """
    from satpy.writers.cf_writer import EPOCH, collect_cf_datasets

    # Default selection: every DataID currently loaded in the Scene.
    selection = list(scn.keys()) if datasets is None else datasets
    dataarrays = _get_dataarrays_from_identifiers(scn, selection)

    # Nothing to convert: hand back an empty Dataset.
    if not dataarrays:
        return xr.Dataset()

    cf_options = {
        "header_attrs": header_attrs,
        "exclude_attrs": exclude_attrs,
        "flatten_attrs": flatten_attrs,
        "pretty": pretty,
        "include_lonlats": include_lonlats,
        "epoch": EPOCH if epoch is None else epoch,
        "include_orig_name": include_orig_name,
        "numeric_name_prefix": numeric_name_prefix,
    }
    # groups=None requests a single, ungrouped collection stored under the None key.
    grouped, _ = collect_cf_datasets(list_dataarrays=dataarrays, groups=None, **cf_options)

    if len(grouped) != 1:
        msg = """The Scene object contains datasets with different areas.
Resample the Scene to have matching dimensions using i.e. scn.resample(resampler="native") """
        raise NotImplementedError(msg)
    return grouped[None]
79 changes: 70 additions & 9 deletions satpy/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,9 @@ def to_xarray_dataset(self, datasets=None):
Returns: :class:`xarray.Dataset`
"""
dataarrays = self._get_dataarrays_from_identifiers(datasets)
from satpy._scene_converters import _get_dataarrays_from_identifiers

dataarrays = _get_dataarrays_from_identifiers(self, datasets)

if len(dataarrays) == 0:
return xr.Dataset()
Expand All @@ -1098,13 +1100,70 @@ def to_xarray_dataset(self, datasets=None):
ds.attrs = mdata
return ds

def _get_dataarrays_from_identifiers(self, identifiers):
    """Return the DataArrays selected by ``identifiers``.

    Each identifier is used to index ``self``. When ``identifiers`` is None,
    every wishlist entry is looked up in ``self._datasets`` instead and the
    entries that yield None are dropped.
    """
    if identifiers is None:
        looked_up = (self._datasets.get(wished) for wished in self._wishlist)
        return [array for array in looked_up if array is not None]
    return [self[key] for key in identifiers]
def to_xarray(self,
              datasets=None,
              header_attrs=None,
              exclude_attrs=None,
              flatten_attrs=False,
              pretty=True,
              include_lonlats=True,
              epoch=None,
              include_orig_name=True,
              numeric_name_prefix='CHANNEL_'):
    """Convert the DataArrays of this Scene into a CF-compliant xarray object.

    This method forwards all of its arguments to
    ``satpy._scene_converters.to_xarray`` with this Scene as ``scn``.

    If all Scene DataArrays are on the same area, an xr.Dataset is returned.
    If Scene DataArrays are on different areas, the conversion currently
    fails, although in future a DataTree object grouped by area might be
    returned.

    Parameters
    ----------
    datasets : iterable, optional
        Satpy Scene datasets to include in the output xr.Dataset.
        Elements can be string name, a wavelength as a number, a DataID,
        or DataQuery object.
        If None (the default), all loaded Scene datasets are included.
    header_attrs : optional
        Global attributes of the output xr.Dataset.
    exclude_attrs : list, optional
        xr.DataArray attribute names to be excluded.
    flatten_attrs : bool
        If True, flatten dict-type attributes.
    pretty : bool
        Don't modify coordinate names, if possible. Makes the file prettier,
        but possibly less consistent.
    include_lonlats : bool
        If True, include 'latitude' and 'longitude' coordinates.
        If the 'area' attribute is a SwathDefinition, latitude and
        longitude coordinates are always included.
    epoch : str, optional
        Reference time for encoding the time coordinates (if available).
        Example format: "seconds since 1970-01-01 00:00:00".
        If None, ``EPOCH`` from ``satpy.writers.cf_writer`` is used.
    include_orig_name : bool
        Include the original dataset name as a variable attribute in the
        xr.Dataset.
    numeric_name_prefix : str
        Prefix to add to each variable whose name starts with a digit.
        Use '' or None to leave this out.

    Returns
    -------
    xr.Dataset
        A CF-compliant xr.Dataset.
    """
    from satpy._scene_converters import to_xarray

    conversion_options = {
        "datasets": datasets,
        "header_attrs": header_attrs,
        "exclude_attrs": exclude_attrs,
        "flatten_attrs": flatten_attrs,
        "pretty": pretty,
        "include_lonlats": include_lonlats,
        "epoch": epoch,
        "include_orig_name": include_orig_name,
        "numeric_name_prefix": numeric_name_prefix,
    }
    return to_xarray(scn=self, **conversion_options)

def images(self):
"""Generate images for all the datasets from the scene."""
Expand Down Expand Up @@ -1205,7 +1264,9 @@ def save_datasets(self, writer=None, filename=None, datasets=None, compute=True,
close any objects that have a "close" method.
"""
dataarrays = self._get_dataarrays_from_identifiers(datasets)
from satpy._scene_converters import _get_dataarrays_from_identifiers

dataarrays = _get_dataarrays_from_identifiers(self, datasets)
if not dataarrays:
raise RuntimeError("None of the requested datasets have been "
"generated or could not be loaded. Requested "
Expand Down
84 changes: 84 additions & 0 deletions satpy/tests/scene_tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,87 @@ def test_geoviews_basic_with_swath(self):
gv_obj = scn.to_geoviews()
# we assume that if we got something back, geoviews can use it
assert gv_obj is not None


class TestToXarrayConversion:
    """Check the behaviour of ``Scene.to_xarray()``."""

    def test_with_empty_scene(self):
        """An empty Scene converts to an empty xr.Dataset."""
        converted = Scene().to_xarray()
        assert isinstance(converted, xr.Dataset)
        assert len(converted.variables) == 0
        assert len(converted.coords) == 0

    @pytest.fixture
    def single_area_scn(self):
        """Build a Scene holding one 2x2 dataset ('var1') on a single area."""
        from pyresample.geometry import AreaDefinition

        area_def = AreaDefinition('test', 'test', 'test',
                                  {'proj': 'geos', 'lon_0': -95.5, 'h': 35786023.0},
                                  2, 2, [-200, -200, 200, 200])
        attrs = {'start_time': datetime(2018, 1, 1), 'area': area_def}
        scene = Scene()
        scene['var1'] = xr.DataArray(da.zeros((2, 2), chunks=-1), dims=('y', 'x'), attrs=attrs)
        return scene

    @pytest.fixture
    def multi_area_scn(self):
        """Build a Scene with a 2x2 'var1' and a 4x4 'var2' on different areas."""
        from pyresample.geometry import AreaDefinition

        scene = Scene()
        # Same projection and extent, different number of pixels per side.
        for var_name, size in (('var1', 2), ('var2', 4)):
            area_def = AreaDefinition('test', 'test', 'test',
                                      {'proj': 'geos', 'lon_0': -95.5, 'h': 35786023.0},
                                      size, size, [-200, -200, 200, 200])
            attrs = {'start_time': datetime(2018, 1, 1), 'area': area_def}
            scene[var_name] = xr.DataArray(da.zeros((size, size), chunks=-1),
                                           dims=('y', 'x'),
                                           attrs=attrs)
        return scene

    def test_with_single_area_scene_type(self, single_area_scn):
        """A single-area Scene converts to an xr.Dataset containing 'var1'."""
        converted = single_area_scn.to_xarray()
        assert isinstance(converted, xr.Dataset)
        assert "var1" in converted.data_vars

    def test_include_lonlats_true(self, single_area_scn):
        """Longitude and latitude coordinates are present when requested."""
        coords = single_area_scn.to_xarray(include_lonlats=True).coords
        assert "latitude" in coords
        assert "longitude" in coords

    def test_include_lonlats_false(self, single_area_scn):
        """Longitude and latitude coordinates are absent when not requested."""
        coords = single_area_scn.to_xarray(include_lonlats=False).coords
        assert "latitude" not in coords
        assert "longitude" not in coords

    def test_dataset_string_accepted(self, single_area_scn):
        """A single dataset name given as a string is accepted."""
        converted = single_area_scn.to_xarray(datasets="var1")
        assert isinstance(converted, xr.Dataset)

    def test_wrong_dataset_key(self, single_area_scn):
        """Requesting a dataset that does not exist raises KeyError."""
        with pytest.raises(KeyError):
            single_area_scn.to_xarray(datasets="var2")

    def test_to_xarray_with_multiple_area_scene(self, multi_area_scn):
        """Converting a Scene with multiple areas raises ValueError."""
        # TODO: in future adapt for DataTree implementation
        with pytest.raises(ValueError):
            multi_area_scn.to_xarray()
Loading

0 comments on commit 8fb627b

Please sign in to comment.