Skip to content

Commit

Permalink
cache_tmp_dir for convertor
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Dec 20, 2024
1 parent c52e459 commit 95bdd08
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
3 changes: 1 addition & 2 deletions cads_adaptors/adaptors/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,13 @@ def retrieve_list_of_results(self, request: dict[str, Any]) -> list[str]:
with dask.config.set(scheduler="threads"):
results_dict = self.post_process(result)

target_dir = str(self.cache_tmp_path)
# TODO?: Generalise format conversion to be a post-processor
paths = self.convert_format(
results_dict,
self.data_format,
context=self.context,
config=self.config,
to_netcdf_kwargs={"target_dir": target_dir},
target_dir=str(self.cache_tmp_path),
)

# A check to ensure that if there is more than one path, and download_format
Expand Down
2 changes: 1 addition & 1 deletion cads_adaptors/adaptors/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def retrieve_list_of_results(self, request: Request) -> list[str]:
self.data_format,
self.context,
self.config,
to_netcdf_kwargs={"target_dir": self.cache_tmp_path},
target_dir=str(self.cache_tmp_path),
)

if len(paths) > 1 and self.download_format == "as_source":
Expand Down
10 changes: 5 additions & 5 deletions cads_adaptors/tools/convertors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ def convert_format(
target_format: str,
context: Context = Context(),
config: dict[str, Any] = {},
target_dir: str = ".",
**runtime_kwargs: dict[str, dict[str, Any]],
) -> list[str]:
target_format = adaptor_tools.handle_data_format(target_format)
post_processing_kwargs = config.get("post_processing_kwargs", {})
for k, v in runtime_kwargs.items():
post_processing_kwargs.setdefault(k, {}).update(v)
post_processing_kwargs.setdefault("target_dir", target_dir)
context.add_stdout(
f"Converting result ({result}) to {target_format} with kwargs: {post_processing_kwargs}"
)
for k, v in runtime_kwargs.items():
post_processing_kwargs.setdefault(k, {}).update(v)

convertor: None | Callable = CONVERTORS.get(target_format, None)

if convertor is not None:
Expand Down Expand Up @@ -181,7 +182,7 @@ def result_to_netcdf_legacy_files(
result: Any,
context: Context = Context(),
to_netcdf_legacy_kwargs: dict[str, Any] = {},
target_dir: str = "",
target_dir: str = ".",
**kwargs,
) -> list[str]:
"""
Expand Down Expand Up @@ -381,7 +382,6 @@ def xarray_dict_to_netcdf(
"compression_options", compression_options
)
out_fname_prefix = to_netcdf_kwargs.pop("out_fname_prefix", out_fname_prefix)
target_dir = to_netcdf_kwargs.pop("target_dir", target_dir)

# Fetch any preset compression options
if isinstance(compression_options, str):
Expand Down
33 changes: 31 additions & 2 deletions tests/test_30_convertors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def test_grib_to_netcdf():
assert isinstance(netcdf_files, list)
assert len(netcdf_files) == 1

os.makedirs("test_subdir", exist_ok=True)
netcdf_files = convertors.grib_to_netcdf_files(
tmp_grib_file, target_dir="./test_subdir"
)
assert isinstance(netcdf_files, list)
assert "/test_subdir/" in netcdf_files[0]
assert len(netcdf_files) == 1

netcdf_files = convertors.grib_to_netcdf_files(
tmp_grib_file, compression_options="default"
)
Expand Down Expand Up @@ -142,7 +150,7 @@ def test_convert_format_to_netcdf(url, target_format="netcdf"):
_, ext = os.path.splitext(url)
with tempfile.TemporaryDirectory() as tmpdirname:
os.chdir(tmpdirname)
tmp_file = f"test.{ext}"
tmp_file = f"test{ext}"
with open(tmp_file, "wb") as f:
f.write(remote_file.content)

Expand All @@ -154,6 +162,17 @@ def test_convert_format_to_netcdf(url, target_format="netcdf"):
_, out_ext = os.path.splitext(converted_files[0])
assert out_ext == EXTENSION_MAPPING.get(target_format, f".{target_format}")

os.makedirs("test_subdir", exist_ok=True)
converted_files = convertors.convert_format(
tmp_file, target_format=target_format, target_dir="./test_subdir"
)
assert isinstance(converted_files, list)
assert len(converted_files) == 1
_, out_ext = os.path.splitext(converted_files[0])
assert out_ext == EXTENSION_MAPPING.get(target_format, f".{target_format}")
if out_ext != ext: # i.e. if a conversion has taken place
assert "/test_subdir/" in converted_files[0]


@pytest.mark.parametrize("url", [TEST_GRIB_FILE, TEST_NC_FILE])
def test_convert_format_to_grib(url, target_format="grib"):
Expand Down Expand Up @@ -182,7 +201,7 @@ def test_convert_format_to_netcdf_legacy(
_, ext = os.path.splitext(url)
with tempfile.TemporaryDirectory() as tmpdirname:
os.chdir(tmpdirname)
tmp_file = f"test.{ext}"
tmp_file = f"test{ext}"
with open(tmp_file, "wb") as f:
f.write(remote_file.content)

Expand All @@ -194,6 +213,16 @@ def test_convert_format_to_netcdf_legacy(
_, out_ext = os.path.splitext(converted_files[0])
assert out_ext == EXTENSION_MAPPING.get(target_format, f".{target_format}")

os.makedirs("test_subdir", exist_ok=True)
converted_files = convertors.convert_format(
tmp_file, target_format=target_format, target_dir="./test_subdir"
)
assert isinstance(converted_files, list)
assert len(converted_files) == 1
_, out_ext = os.path.splitext(converted_files[0])
assert out_ext == EXTENSION_MAPPING.get(target_format, f".{target_format}")
if out_ext != ext: # i.e. if a conversion has taken place
assert "/test_subdir/" in converted_files[0]

def test_safely_rename_variable():
import xarray as xr
Expand Down

0 comments on commit 95bdd08

Please sign in to comment.