From 2db1426aba0942c7086f1d8c835229fe48391ed6 Mon Sep 17 00:00:00 2001 From: Sylvain Brunato Date: Tue, 17 Dec 2024 11:45:17 +0100 Subject: [PATCH] test: to_xarray using fsspec --- eodag_cube/api/product/_product.py | 11 +- eodag_cube/utils/__init__.py | 27 ++-- eodag_cube/utils/xarray.py | 8 +- setup.cfg | 1 + setup.py | 1 + tests/context.py | 18 ++- tests/units/test_eoproduct.py | 243 +++++++++++++++++++++++------ tests/utils.py | 29 ++++ 8 files changed, 268 insertions(+), 70 deletions(-) diff --git a/eodag_cube/api/product/_product.py b/eodag_cube/api/product/_product.py index 1d11e1f..d4aeaf1 100644 --- a/eodag_cube/api/product/_product.py +++ b/eodag_cube/api/product/_product.py @@ -275,7 +275,7 @@ def _get_storage_options( auth = self.downloader_auth.authenticate() if self.downloader_auth else None # order if product is offline - if self.properties["storageStatus"] == OFFLINE_STATUS and hasattr( + if self.properties.get("storageStatus") == OFFLINE_STATUS and hasattr( self.downloader, "order" ): self.downloader.order(self, auth, wait=wait, timeout=timeout) @@ -351,14 +351,20 @@ def to_xarray( """ if asset_key is None and len(self.assets) > 0: # assets + + # have roles been set in assets ? + roles_exist = any("roles" in a for a in self.assets) + xd = XarrayDict() - with concurrent.futures.ThreadPoolExecutor() as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: futures = ( executor.submit(self.to_xarray, key, wait, timeout, **xarray_kwargs) for key, asset in self.assets.items() if roles and asset.get("roles") and any(r in asset["roles"] for r in roles) + or not roles + or not roles_exist ) for future in concurrent.futures.as_completed(futures): try: @@ -381,6 +387,7 @@ def to_xarray( except ( UnsupportedDatasetAddressScheme, FileNotFoundError, + IsADirectoryError, DatasetCreationError, ) as e: logger.debug(f"Cannot open {self} {asset_key if asset_key else ''}: {e}") diff --git a/eodag_cube/utils/__init__.py b/eodag_cube/utils/__init__.py index 2ee12e0..a87453b 100644 --- a/eodag_cube/utils/__init__.py +++ b/eodag_cube/utils/__init__.py @@ -47,24 +47,23 @@ def fsspec_file_headers(file: OpenFile) -> Optional[dict[str, Any]]: :param file: fsspec https OpenFile :returns: file headers or ``None`` """ - headers = None + file_kwargs = getattr(file, "kwargs", {}) if "https" in file.fs.protocol: try: - resp = requests.head(file.path, **file.kwargs) + resp = requests.head(file.path, **file_kwargs) resp.raise_for_status() except requests.RequestException: pass else: - headers = resp.headers - if not headers: - # if HEAD method is not available, try to get a minimal part of the file - try: - resp = requests.get(file.path, stream=True, **file.kwargs) - resp.raise_for_status() - except requests.RequestException: - pass - else: - headers = resp.headers + return resp.headers + # if HEAD method is not available, try to get a minimal part of the file + try: + resp = requests.get(file.path, **file_kwargs) + resp.raise_for_status() + except requests.RequestException: + pass + else: + return resp.headers return None @@ -84,8 +83,8 @@ def fsspec_file_extension(file: OpenFile) -> Optional[str]: Optional[str], parse_header(content_disposition).get_param("filename", None), ) - _, extension = os.path.splitext(filename) if filename else None, None - if extension: + _, extension = os.path.splitext(filename) if filename else (None, None) + if not extension: mime_type = headers.get("content-type", "").split(";")[0] if mime_type not in IGNORED_MIMETYPES: extension = guess_extension(mime_type) diff --git a/eodag_cube/utils/xarray.py b/eodag_cube/utils/xarray.py index bf6ce32..d1dadc5 100644 --- a/eodag_cube/utils/xarray.py +++ b/eodag_cube/utils/xarray.py @@ -26,7 +26,6 @@ import fsspec import rioxarray import xarray as xr -from fsspec.implementations.local import LocalFileOpener from eodag_cube.types import XarrayDict from eodag_cube.utils import fsspec_file_extension @@ -61,8 +60,9 @@ def try_open_dataset(file: OpenFile, **xarray_kwargs: dict[str, Any]) -> xr.Data :param file: fsspec https OpenFile :param xarray_kwargs: (optional) keyword arguments passed to xarray.open_dataset :returns: opened xarray dataset - """ + LOCALFILE_ONLY_ENGINES = ["netcdf4", "cfgrib"] + if engine := xarray_kwargs.pop("engine", None): all_engines = [ engine, @@ -70,7 +70,7 @@ def try_open_dataset(file: OpenFile, **xarray_kwargs: dict[str, Any]) -> xr.Data else: all_engines = guess_engines(file) or list(xr.backends.list_engines().keys()) - if isinstance(file, LocalFileOpener): + if "file" in file.fs.protocol: engines = all_engines # use path str as cfgrib does not support fsspec OpenFile as input @@ -93,7 +93,7 @@ def try_open_dataset(file: OpenFile, **xarray_kwargs: dict[str, Any]) -> xr.Data else: # remove engines that do not support remote access # https://tutorial.xarray.dev/intermediate/remote_data/remote-data.html#supported-format-read-from-buffers-remote-access - engines = [eng for eng in all_engines if eng not in ["netcdf4", "cfgrib"]] + engines = [eng for eng in all_engines if eng not in LOCALFILE_ONLY_ENGINES] file_or_path = file diff --git a/setup.cfg b/setup.cfg index 6c4ee7c..dd6b634 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,6 +21,7 @@ line_length=88 known_first_party = eodag,tests known_third_party = concurrent.futures default_section = THIRDPARTY +ensure_newline_before_comments = True skip = .git, __pycache__, diff --git a/setup.py b/setup.py index 2dc2739..3cef587 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ "faker", "coverage", "moto >= 5", + "responses < 0.24.0", "twine", "wheel", ] diff --git a/tests/context.py b/tests/context.py index 27b13a5..3b48e88 100644 --- a/tests/context.py +++ b/tests/context.py @@ -31,11 +31,27 @@ from eodag_cube.api.product.drivers.generic import GenericDriver from eodag_cube.api.product.drivers.sentinel2_l1c import Sentinel2L1C from eodag_cube.api.product.drivers.stac_assets import StacAssets +from eodag_cube.utils import fsspec_file_headers, fsspec_file_extension +from eodag_cube.utils.exceptions import DatasetCreationError +from eodag_cube.utils.xarray import ( + guess_engines, + try_open_dataset, + build_local_xarray_dict, +) from eodag.plugins.authentication.base import Authentication from eodag.plugins.authentication.aws_auth import AwsAuth +from eodag.plugins.authentication.header import HTTPHeaderAuth +from eodag.plugins.authentication.qsauth import HttpQueryStringAuth from eodag.plugins.download.base import Download from eodag.plugins.download.aws import AwsDownload -from eodag.utils import DEFAULT_PROJ, path_to_uri +from eodag.utils import ( + DEFAULT_PROJ, + path_to_uri, + USER_AGENT, + DEFAULT_DOWNLOAD_TIMEOUT, + DEFAULT_DOWNLOAD_WAIT, + path_to_uri, +) from eodag.utils.exceptions import ( AddressNotFound, DownloadError, diff --git a/tests/units/test_eoproduct.py b/tests/units/test_eoproduct.py index 9a39e7c..8d16797 100644 --- a/tests/units/test_eoproduct.py +++ b/tests/units/test_eoproduct.py @@ -19,30 +19,26 @@ import itertools import os import random -import shutil -import urllib.request -from pathlib import Path -from tempfile import TemporaryDirectory import numpy as np import xarray as xr from rasterio.session import AWSSession -from eodag_cube.types import XarrayDict -from tests import ( - TEST_GRIB_FILE_PATH, - TEST_GRIB_FILENAME, - TEST_RESOURCES_PATH, - EODagTestCase, -) +from tests import TEST_GRIB_FILE_PATH, TEST_GRIB_FILENAME, EODagTestCase from tests.context import ( + DEFAULT_DOWNLOAD_TIMEOUT, + DEFAULT_DOWNLOAD_WAIT, DEFAULT_PROJ, + USER_AGENT, Authentication, AwsAuth, AwsDownload, + DatasetCreationError, Download, DownloadError, EOProduct, + HTTPHeaderAuth, + HttpQueryStringAuth, NoDriver, PluginConfig, Sentinel2L1C, @@ -250,44 +246,193 @@ def test_get_rio_env(self): self.assertEqual(rio_env["AWS_S3_ENDPOINT"], "some.where") self.assertEqual(rio_env["AWS_VIRTUAL_HOSTING"], "FALSE") - def populate_directory_with_heterogeneous_files(self, destination): - """ - Put various files in the destination directory: - - a NetCDF file - - a JPEG2000 file - - an XML file - """ - # Copy all files from a grib product - cams_air_quality_product_path = os.path.join( - TEST_RESOURCES_PATH, - "products", - "cams-europe-air-quality-forecasts", - ) - shutil.copytree(cams_air_quality_product_path, destination, dirs_exist_ok=True) - - # Copy files from an S2A product - s2a_path = os.path.join( - TEST_RESOURCES_PATH, - "products", - "S2A_MSIL1C_20180101T105441_N0206_R051_T31TDH_20180101T124911.SAFE", - ) - shutil.copytree(s2a_path, destination, dirs_exist_ok=True) - - def test_build_xarray_dict(self): - with TemporaryDirectory(prefix="eodag-cube-tests") as tmp_dir: - product = EOProduct( - self.provider, self.eoproduct_props, productType=self.product_type - ) - product.location = "file:" + urllib.request.pathname2url(tmp_dir) - self.populate_directory_with_heterogeneous_files(tmp_dir) + def test_get_storage_options_http_headers(self): + """_get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + # http headers auth + product.register_downloader( + Download("foo", PluginConfig()), + HTTPHeaderAuth( + "foo", + PluginConfig.from_mapping( + { + "credentials": {"apikey": "foo"}, + "headers": {"X-API-Key": "{apikey}"}, + } + ), + ), + ) + self.assertDictEqual( + product._get_storage_options(), + { + "path": self.download_url, + "headers": {"X-API-Key": "foo", **USER_AGENT}, + }, + ) - xarray_dict = product._build_xarray_dict() + def test_get_storage_options_http_qs(self): + """_get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + # http qs auth + product.register_downloader( + Download("foo", PluginConfig()), + HttpQueryStringAuth( + "foo", + PluginConfig.from_mapping( + { + "credentials": {"apikey": "foo"}, + } + ), + ), + ) + self.assertDictEqual( + product._get_storage_options(), + { + "path": f"{self.download_url}?apikey=foo", + "headers": USER_AGENT, + }, + ) - self.assertIsInstance(xarray_dict, XarrayDict) - self.assertEqual(len(xarray_dict), 2) - for key, value in xarray_dict.items(): - self.assertIn(Path(key).suffix, {".nc", ".jp2"}) - self.assertIsInstance(value, xr.Dataset) + def test_get_storage_options_s3(self): + """_get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + # http s3 auth + product.register_downloader( + Download( + "foo", + PluginConfig.from_mapping( + { + "s3_endpoint": "http://foo.bar", + } + ), + ), + AwsAuth( + "foo", + PluginConfig.from_mapping( + { + "credentials": { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_session_token": "baz", + }, + } + ), + ), + ) + self.assertDictEqual( + product._get_storage_options(), + { + "path": self.download_url, + "key": "foo", + "secret": "bar", + "token": "baz", + "client_kwargs": {"endpoint_url": "http://foo.bar"}, + }, + ) - for ds in xarray_dict.values(): - ds.close() + def test_get_storage_options_error(self): + """_get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + with self.assertRaises( + DatasetCreationError, msg=f"foo not found in {product} assets" + ): + product._get_storage_options(asset_key="foo") + + @mock.patch("eodag_cube.api.product._product.fsspec.filesystem") + @mock.patch( + "eodag_cube.api.product._product.EOProduct._get_storage_options", autospec=True + ) + def test_get_fsspec_file(self, mock_storage_options, mock_fs): + """get_fsspec_file should call fsspec open with appropriate args""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + # https + mock_storage_options.return_value = {"path": "https://foo.bar", "baz": "qux"} + file = product.get_fsspec_file() + mock_fs.assert_called_once_with("https", baz="qux") + mock_fs.return_value.open.assert_called_once_with(path="https://foo.bar") + self.assertEqual(file, mock_fs.return_value.open.return_value) + mock_fs.reset_mock() + # s3 + mock_storage_options.return_value = {"path": "s3://foo.bar", "baz": "qux"} + file = product.get_fsspec_file() + mock_fs.assert_called_once_with("s3", baz="qux") + mock_fs.return_value.open.assert_called_once_with(path="s3://foo.bar") + self.assertEqual(file, mock_fs.return_value.open.return_value) + mock_fs.reset_mock() + # local + mock_storage_options.return_value = { + "path": os.path.join("foo", "bar"), + "baz": "qux", + } + file = product.get_fsspec_file() + mock_fs.assert_called_once_with("file", baz="qux") + mock_fs.return_value.open.assert_called_once_with( + path=os.path.join("foo", "bar") + ) + self.assertEqual(file, mock_fs.return_value.open.return_value) + mock_fs.reset_mock() + # not found + mock_storage_options.return_value = {"baz": "qux"} + with self.assertRaises( + UnsupportedDatasetAddressScheme, msg=f"Could not get {product} path" + ): + product.get_fsspec_file() + + @mock.patch("eodag_cube.api.product._product.try_open_dataset", autospec=True) + @mock.patch( + "eodag_cube.api.product._product.EOProduct.get_fsspec_file", autospec=True + ) + def test_to_xarray(self, mock_get_file, mock_open_ds): + """to_xarrray should return well built XarrayDict""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + mock_open_ds.return_value = xr.Dataset() + mock_get_file.return_value.path = "http://foo.bar" + xd = product.to_xarray(foo="bar") + mock_get_file.assert_called_once_with( + product, None, DEFAULT_DOWNLOAD_WAIT, DEFAULT_DOWNLOAD_TIMEOUT + ) + mock_open_ds.assert_called_once_with(mock_get_file.return_value, foo="bar") + self.assertEqual(len(xd), 1) + self.assertTrue(xd["data"].equals(mock_open_ds.return_value)) + + @mock.patch("eodag_cube.api.product._product.try_open_dataset", autospec=True) + @mock.patch( + "eodag_cube.api.product._product.EOProduct.get_fsspec_file", autospec=True + ) + def test_to_xarray_assets(self, mock_get_file, mock_open_ds): + """to_xarrray should return well built XarrayDict""" + product = EOProduct( + self.provider, self.eoproduct_props, productType=self.product_type + ) + product.assets.update( + {"foo": {"href": "http://foo.bar"}}, + ) + product.assets.update( + {"bar": {"href": "http://bar.baz"}}, + ) + + mock_open_ds.return_value = xr.Dataset() + mock_get_file.return_value.path = "http://foo.bar" + xd = product.to_xarray(foo="bar") + mock_get_file.assert_any_call( + product, "foo", DEFAULT_DOWNLOAD_WAIT, DEFAULT_DOWNLOAD_TIMEOUT + ) + mock_get_file.assert_any_call( + product, "bar", DEFAULT_DOWNLOAD_WAIT, DEFAULT_DOWNLOAD_TIMEOUT + ) + mock_open_ds.assert_called_with(mock_get_file.return_value, foo="bar") + self.assertEqual(len(xd), 2) + self.assertTrue(xd["foo"].equals(mock_open_ds.return_value)) + self.assertTrue(xd["bar"].equals(mock_open_ds.return_value)) diff --git a/tests/utils.py b/tests/utils.py index 0bff2f4..66bd272 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,9 +16,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil + # All tests files should import mock from this place from unittest import mock # noqa +from tests import TEST_RESOURCES_PATH + def no_blanks(string): """Removes all the blanks in string @@ -29,3 +34,27 @@ def no_blanks(string): :returns the same string with all blank characters removed """ return string.replace("\n", "").replace("\t", "").replace(" ", "") + + +def populate_directory_with_heterogeneous_files(destination): + """ + Put various files in the destination directory: + - a NetCDF file + - a JPEG2000 file + - an XML file + """ + # Copy all files from a grib product + cams_air_quality_product_path = os.path.join( + TEST_RESOURCES_PATH, + "products", + "cams-europe-air-quality-forecasts", + ) + shutil.copytree(cams_air_quality_product_path, destination, dirs_exist_ok=True) + + # Copy files from an S2A product + s2a_path = os.path.join( + TEST_RESOURCES_PATH, + "products", + "S2A_MSIL1C_20180101T105441_N0206_R051_T31TDH_20180101T124911.SAFE", + ) + shutil.copytree(s2a_path, destination, dirs_exist_ok=True)