Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use self.cache_tmp_dir for MarsCdsAdaptor #253

Merged
merged 13 commits into from
Dec 20, 2024
22 changes: 16 additions & 6 deletions cads_adaptors/adaptors/mars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib
from typing import Any, BinaryIO

from cads_adaptors.adaptors import Context, Request, cds
Expand Down Expand Up @@ -49,10 +50,11 @@ def get_mars_server_list(config) -> list[str]:

def execute_mars(
request: dict[str, Any] | list[dict[str, Any]],
context: Context,
context: Context = Context(),
config: dict[str, Any] = dict(),
mapping: dict[str, Any] = dict(),
target: str = "data.grib",
target_fname: str = "data.grib",
target_dir: str | pathlib.Path = "",
) -> str:
from cads_mars_server import client as mars_client

Expand All @@ -63,6 +65,8 @@ def execute_mars(
if config.get("embargo") is not None:
requests, _cacheable = implement_embargo(requests, config["embargo"])

target = str(pathlib.Path(target_dir) / target_fname)

split_on_keys = ALWAYS_SPLIT_ON + ensure_list(config.get("split_on", []))
requests = split_requests_on_keys(requests, split_on_keys, context, mapping)

Expand Down Expand Up @@ -118,7 +122,11 @@ class DirectMarsCdsAdaptor(cds.AbstractCdsAdaptor):
resources = {"MARS_CLIENT": 1}

def retrieve(self, request: Request) -> BinaryIO:
    """Execute *request* against MARS and return the result opened in binary mode.

    The MARS output is written into this adaptor's temporary cache directory
    (``self.cache_tmp_path``) rather than the process working directory, so
    concurrent retrievals do not collide on a shared ``data.grib`` target.
    """
    # execute_mars returns the path of the file the MARS client wrote.
    result = execute_mars(
        request,
        context=self.context,
        target_dir=self.cache_tmp_path,
    )
    return open(result, "rb")


Expand Down Expand Up @@ -178,22 +186,24 @@ def retrieve_list_of_results(self, request: dict[str, Any]) -> list[str]:
# Call normalise_request to set self.mapped_requests
request = self.normalise_request(request)

result: Any = execute_mars(
result = execute_mars(
self.mapped_requests,
context=self.context,
config=self.config,
mapping=self.mapping,
target_dir=self.cache_tmp_path,
)

with dask.config.set(scheduler="threads"):
result = self.post_process(result)
results_dict = self.post_process(result)

# TODO?: Generalise format conversion to be a post-processor
paths = self.convert_format(
result,
results_dict,
self.data_format,
context=self.context,
config=self.config,
target_dir=str(self.cache_tmp_path),
)

# A check to ensure that if there is more than one path, and download_format
Expand Down
7 changes: 6 additions & 1 deletion cads_adaptors/adaptors/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,16 @@ def retrieve_list_of_results(self, request: Request) -> list[str]:
context=self.context,
config=self.config,
mapping=self.mapping,
target_dir=self.cache_tmp_path,
)

with dask.config.set(scheduler="threads"):
paths = self.convert_format(
result, self.data_format, self.context, self.config
result,
self.data_format,
self.context,
self.config,
target_dir=str(self.cache_tmp_path),
)

if len(paths) > 1 and self.download_format == "as_source":
Expand Down
4 changes: 4 additions & 0 deletions cads_adaptors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class RoocsValueError(ValueError):
"""Raised when a ROOCS request fails due to a value error."""


class CdsFormatConversionError(RuntimeError):
    """Raised when converting a CDS result to the requested data format fails."""


# Distinct from CdsFormatConversionError: this signals a problem with the
# request/adaptor configuration itself, not with converting the result data.
class CdsConfigurationError(ValueError):
    """Raised when a CDS request fails due to a configuration error."""

Expand Down
45 changes: 21 additions & 24 deletions cads_adaptors/tools/convertors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import xarray as xr

from cads_adaptors.adaptors import Context
from cads_adaptors.exceptions import CdsFormatConversionError
from cads_adaptors.tools import adaptor_tools
from cads_adaptors.tools.general import ensure_list

Expand Down Expand Up @@ -35,7 +36,7 @@
def add_user_log_and_raise_error(
    message: str,
    context: Context = Context(),
    thisError: type[Exception] = CdsFormatConversionError,
) -> NoReturn:
    """Surface *message* to the user via *context*, then raise it as an exception.

    Parameters
    ----------
    message : str
        Error text shown to the user and used as the exception message.
    context : Context
        Adaptor context used for user-visible logging.
        NOTE(review): the ``Context()`` default is evaluated once at import
        time and shared between calls — confirm Context is stateless enough
        for that, or callers always pass their own.
    thisError : type[Exception]
        Exception class to raise; defaults to CdsFormatConversionError.
    """
    context.add_user_visible_error(message)
    raise thisError(message)
Expand All @@ -46,13 +47,17 @@ def convert_format(
target_format: str,
context: Context = Context(),
config: dict[str, Any] = {},
target_dir: str = ".",
**runtime_kwargs: dict[str, dict[str, Any]],
) -> list[str]:
target_format = adaptor_tools.handle_data_format(target_format)
post_processing_kwargs = config.get("post_processing_kwargs", {})
for k, v in runtime_kwargs.items():
post_processing_kwargs.setdefault(k, {}).update(v)
post_processing_kwargs.setdefault("target_dir", target_dir)
context.add_stdout(
f"Converting result ({result}) to {target_format} with kwargs: {post_processing_kwargs}"
)

convertor: None | Callable = CONVERTORS.get(target_format, None)

if convertor is not None:
Expand Down Expand Up @@ -124,7 +129,6 @@ def result_to_grib_files(
add_user_log_and_raise_error(
f"Unable to convert result of type {result_type} to grib files. result:\n{result}",
context=context,
thisError=ValueError,
)


Expand Down Expand Up @@ -171,14 +175,14 @@ def result_to_netcdf_files(
add_user_log_and_raise_error(
f"Unable to convert result of type {result_type} to netCDF files. result:\n{result}",
context=context,
thisError=ValueError,
)


def result_to_netcdf_legacy_files(
result: Any,
context: Context = Context(),
to_netcdf_legacy_kwargs: dict[str, Any] = {},
target_dir: str = ".",
**kwargs,
) -> list[str]:
"""
Expand Down Expand Up @@ -238,7 +242,6 @@ def result_to_netcdf_legacy_files(
add_user_log_and_raise_error(
f"Unable to convert result of type {type(result)} to 'netcdf_legacy' files. result:\n{result}",
context=context,
thisError=ValueError,
)

if filter_rules:
Expand All @@ -265,7 +268,7 @@ def result_to_netcdf_legacy_files(

nc_files = []
for out_fname_base, grib_file in result.items():
out_fname = f"{out_fname_base}.nc"
out_fname = os.path.join(target_dir, f"{out_fname_base}.nc")
nc_files.append(out_fname)
command = ensure_list(command)
os.system(" ".join(command + ["-o", out_fname, grib_file]))
Expand All @@ -275,7 +278,7 @@ def result_to_netcdf_legacy_files(
"We are unable to convert this GRIB data to netCDF, "
"please download as GRIB and convert to netCDF locally.\n"
)
add_user_log_and_raise_error(message, context=context, thisError=RuntimeError)
add_user_log_and_raise_error(message, context=context)

return nc_files

Expand All @@ -302,9 +305,7 @@ def unknown_filetype_to_grib_files(
)
return [infile]
else:
add_user_log_and_raise_error(
f"Unknown file type: {infile}", context=context, thisError=ValueError
)
add_user_log_and_raise_error(f"Unknown file type: {infile}", context=context)


def unknown_filetype_to_netcdf_files(
Expand All @@ -320,25 +321,21 @@ def unknown_filetype_to_netcdf_files(
context.add_stdout(f"Converting {infile} to netCDF files with kwargs: {kwargs}")
return grib_to_netcdf_files(infile, context=context, **kwargs)
else:
add_user_log_and_raise_error(
f"Unknown file type: {infile}", context=context, thisError=ValueError
)
add_user_log_and_raise_error(f"Unknown file type: {infile}", context=context)


def grib_to_netcdf_files(
grib_file: str,
open_datasets_kwargs: None | dict[str, Any] | list[dict[str, Any]] = None,
post_open_datasets_kwargs: dict[str, Any] = {},
to_netcdf_kwargs: dict[str, Any] = {},
context: Context = Context(),
**kwargs,
):
to_netcdf_kwargs.update(kwargs.pop("to_netcdf_kwargs", {}))
grib_file = os.path.realpath(grib_file)

context.add_stdout(
f"Converting {grib_file} to netCDF files with:\n"
f"to_netcdf_kwargs: {to_netcdf_kwargs}\n"
f"to_netcdf_kwargs: {kwargs}\n"
f"open_datasets_kwargs: {open_datasets_kwargs}\n"
f"post_open_datasets_kwargs: {post_open_datasets_kwargs}\n"
)
Expand All @@ -357,11 +354,9 @@ def grib_to_netcdf_files(
)
context.add_user_visible_error(message=message)
context.add_stderr(message=message)
raise RuntimeError(message)
raise CdsFormatConversionError(message)

out_nc_files = xarray_dict_to_netcdf(
datasets, context=context, to_netcdf_kwargs=to_netcdf_kwargs
)
out_nc_files = xarray_dict_to_netcdf(datasets, context=context, **kwargs)

return out_nc_files

Expand All @@ -372,12 +367,16 @@ def xarray_dict_to_netcdf(
compression_options: str | dict[str, Any] = "default",
to_netcdf_kwargs: dict[str, Any] = {},
out_fname_prefix: str = "",
target_dir: str = "",
**kwargs,
) -> list[str]:
"""
Convert a dictionary of xarray datasets to netCDF files, where the key of the dictionary
is used in the filename.
"""
# Untangle any nested kwargs (I don't think this is necessary anymore)
to_netcdf_kwargs.update(kwargs.pop("to_netcdf_kwargs", {}))

# Check if compression_options or out_fname_prefix have been provided in to_netcdf_kwargs
compression_options = to_netcdf_kwargs.pop(
"compression_options", compression_options
Expand All @@ -396,7 +395,7 @@ def xarray_dict_to_netcdf(
"encoding": {var: compression_options for var in dataset},
}
)
out_fname = f"{out_fname_prefix}{out_fname_base}.nc"
out_fname = os.path.join(target_dir, f"{out_fname_prefix}{out_fname_base}.nc")
context.add_stdout(f"Writing {out_fname} with kwargs:\n{to_netcdf_kwargs}")
dataset.to_netcdf(out_fname, **to_netcdf_kwargs)
out_nc_files.append(out_fname)
Expand Down Expand Up @@ -435,7 +434,6 @@ def open_result_as_xarray_dictionary(
add_user_log_and_raise_error(
f"Unable to open result as an xarray dataset: \n{result}",
context=context,
thisError=ValueError,
)


Expand All @@ -457,7 +455,6 @@ def open_file_as_xarray_dictionary(
add_user_log_and_raise_error(
f"Unable to open file {infile} as an xarray dataset.",
context=context,
thisError=ValueError,
)


Expand Down Expand Up @@ -487,7 +484,7 @@ def safely_rename_variable(dataset: xr.Dataset, rename: dict[str, str]) -> xr.Da
if (new_name not in rename_order) or (
rename_order.index(conflict) > rename_order.index(new_name)
):
raise ValueError(
raise CdsFormatConversionError(
f"Refusing to to rename to existing variable name: {conflict}->{new_name}"
)

Expand Down
43 changes: 43 additions & 0 deletions tests/test_20_adaptor_multi.py → tests/test_15_adaptor_multi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import os

import requests

from cads_adaptors import AbstractAdaptor
from cads_adaptors.adaptors import multi

TEST_GRIB_FILE = (
"https://get.ecmwf.int/repository/test-data/cfgrib/era5-levels-members.grib"
)

FORM = {
"level": ["500", "850"],
"time": ["12:00", "00:00"],
Expand Down Expand Up @@ -177,3 +185,38 @@ def test_multi_adaptor_split_adaptors_dont_split_keys():
assert "dont_split" in sub_adaptors["mean"][1].keys()
assert "dont_split" not in sub_adaptors["max"][1].keys()
assert "area" in sub_adaptors["max"][1].keys()


def test_convert_format(tmp_path, monkeypatch):
    """MultiMarsCdsAdaptor.convert_format turns a GRIB file into netCDF,
    both in the working directory (default) and in an explicit target_dir.
    """
    monkeypatch.chdir(tmp_path)
    multi_adaptor = multi.MultiMarsCdsAdaptor({}, {})

    assert hasattr(multi_adaptor, "convert_format")

    # Fetch a small sample GRIB file (network access required).
    url = TEST_GRIB_FILE
    remote_file = requests.get(url)
    _, ext = os.path.splitext(url)

    tmp_file = f"test{ext}"
    with open(tmp_file, "wb") as f:
        f.write(remote_file.content)

    # Default target_dir: output lands in the current working directory.
    converted_files = multi_adaptor.convert_format(
        tmp_file,
        "netcdf",
    )
    assert isinstance(converted_files, list)
    assert len(converted_files) == 1
    _, out_ext = os.path.splitext(converted_files[0])
    assert out_ext == ".nc"

    # Explicit target_dir: output must land inside that directory.
    test_subdir = "./test_subdir"
    os.makedirs(test_subdir, exist_ok=True)
    converted_files = multi_adaptor.convert_format(
        tmp_file, "netcdf", target_dir=test_subdir
    )
    assert isinstance(converted_files, list)
    assert len(converted_files) == 1
    _, out_ext = os.path.splitext(converted_files[0])
    assert out_ext == ".nc"
    # Compare path components instead of a raw "/test_subdir/" substring so
    # the assertion also holds with Windows path separators.
    assert os.path.basename(os.path.dirname(converted_files[0])) == "test_subdir"
41 changes: 41 additions & 0 deletions tests/test_15_mars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os

import requests

from cads_adaptors.adaptors import mars

TEST_GRIB_FILE = (
"https://get.ecmwf.int/repository/test-data/cfgrib/era5-levels-members.grib"
)


def test_get_mars_servers():
mars_servers = mars.get_mars_server_list(
Expand All @@ -24,3 +30,38 @@ def test_get_mars_servers_envvar():
mars_servers = mars.get_mars_server_list({})
assert len(mars_servers) == 1
assert mars_servers[0] == "http://a-test-server.url"


def test_convert_format(tmp_path, monkeypatch):
    """MarsCdsAdaptor.convert_format turns a GRIB file into netCDF,
    both in the working directory (default) and in an explicit target_dir.
    """
    monkeypatch.chdir(tmp_path)
    mars_adaptor = mars.MarsCdsAdaptor({}, {})

    assert hasattr(mars_adaptor, "convert_format")

    # Fetch a small sample GRIB file (network access required).
    url = TEST_GRIB_FILE
    remote_file = requests.get(url)
    _, ext = os.path.splitext(url)

    tmp_file = f"test{ext}"
    with open(tmp_file, "wb") as f:
        f.write(remote_file.content)

    # Default target_dir: output lands in the current working directory.
    converted_files = mars_adaptor.convert_format(
        tmp_file,
        "netcdf",
    )
    assert isinstance(converted_files, list)
    assert len(converted_files) == 1
    _, out_ext = os.path.splitext(converted_files[0])
    assert out_ext == ".nc"

    # Explicit target_dir: output must land inside that directory.
    test_subdir = "./test_subdir"
    os.makedirs(test_subdir, exist_ok=True)
    converted_files = mars_adaptor.convert_format(
        tmp_file, "netcdf", target_dir=test_subdir
    )
    assert isinstance(converted_files, list)
    assert len(converted_files) == 1
    _, out_ext = os.path.splitext(converted_files[0])
    assert out_ext == ".nc"
    # Compare path components instead of a raw "/test_subdir/" substring so
    # the assertion also holds with Windows path separators.
    assert os.path.basename(os.path.dirname(converted_files[0])) == "test_subdir"
Loading
Loading