Avoid roundtrip to qcodes format when loading as xarray from netcdf #5627

Merged: 7 commits, Jan 19, 2024

Changes from 6 commits
3 changes: 3 additions & 0 deletions docs/changes/newsfragments/5627.improved
@@ -0,0 +1,3 @@
When loading a QCoDeS dataset from a netcdf file using ``load_from_netcdf`` and converting it to an xarray dataset using
``to_xarray_dataset`` or ``cache.to_xarray_dataset``, we avoid converting the data to QCoDeS format and back to xarray format.
This should save time and avoid potential corner cases when roundtripping the data.
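
A minimal usage sketch of the path this entry describes (the file name below is hypothetical):

from qcodes.dataset import load_from_netcdf

ds = load_from_netcdf("qcodes_1_some-guid.nc")
# Converts straight from the netcdf file to xarray, without the previous
# roundtrip through QCoDeS format.
xr_ds = ds.to_xarray_dataset()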
40 changes: 39 additions & 1 deletion src/qcodes/dataset/data_set_cache.py
@@ -2,11 +2,13 @@

import logging
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import numpy as np

from qcodes.dataset.descriptions.rundescriber import RunDescriber
from qcodes.dataset.exporters.export_info import ExportInfo
from qcodes.dataset.sqlite.connection import ConnectionPlus
from qcodes.dataset.sqlite.queries import completed, load_new_data_for_rundescriber

@@ -25,7 +27,7 @@

# used in forward refs that cannot be detected
from .data_set import DataSet # noqa F401
from .data_set_in_memory import DataSetInMem # noqa F401
from .data_set_in_memory import DataSetInMem
from .data_set_protocol import DataSetProtocol, ParameterData

DatasetType = TypeVar("DatasetType", bound="DataSetProtocol", covariant=True)
@@ -453,6 +455,42 @@ class DataSetCacheInMem(DataSetCache["DataSetInMem"]):
pass


class DataSetCacheDeferred(DataSetCacheInMem):
def __init__(self, dataset: DataSetInMem, loaded_data: Path | str):
super().__init__(dataset)
self._xr_dataset_path = Path(loaded_data)

def load_data_from_db(self) -> None:
if self._data == {}:
loaded_data = self._load_xr_dataset()
self._data = self._dataset._from_xarray_dataset_to_qcodes_raw_data(
loaded_data
)

def _load_xr_dataset(self) -> xr.Dataset:
import cf_xarray as cfxr
import xarray as xr

loaded_data = xr.load_dataset(self._xr_dataset_path, engine="h5netcdf")
loaded_data = cfxr.coding.decode_compress_to_multi_index(loaded_data)
export_info = ExportInfo.from_str(loaded_data.attrs.get("export_info", ""))
export_info.export_paths["nc"] = str(self._xr_dataset_path)
loaded_data.attrs["export_info"] = export_info.to_str()
return loaded_data

def to_xarray_dataset(
self, *, use_multi_index: Literal["auto", "always", "never"] = "auto"
) -> xr.Dataset:
loaded_data = self._load_xr_dataset()
if use_multi_index == "always":
ds = loaded_data.stack()
elif use_multi_index == "never":
ds = loaded_data.unstack()
else:
ds = loaded_data
return ds


class DataSetCacheWithDBBackend(DataSetCache["DataSet"]):
def load_data_from_db(self) -> None:
"""
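
For orientation, a rough sketch of how the deferred cache added above is expected to behave (illustrative only; the file name is made up, and cache._data is a private attribute inspected here the same way the new test further down does):

from qcodes.dataset import load_from_netcdf

ds = load_from_netcdf("export.nc")   # cache is a DataSetCacheDeferred, holding only the path
assert ds.cache._data == {}          # no data pulled into the QCoDeS cache yet

xr_ds = ds.to_xarray_dataset()       # served directly from the netcdf file
assert ds.cache._data == {}          # still no conversion to QCoDeS format

df = ds.to_pandas_dataframe()        # needs QCoDeS format, so load_data_from_db runs
assert ds.cache._data != {}          # cache is now populated from the file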
128 changes: 62 additions & 66 deletions src/qcodes/dataset/data_set_in_memory.py
@@ -44,7 +44,7 @@
)
from qcodes.utils import NumpyJSONEncoder

from .data_set_cache import DataSetCacheInMem
from .data_set_cache import DataSetCacheDeferred, DataSetCacheInMem
from .dataset_helpers import _add_run_to_runs_table
from .descriptions.versioning import serialization as serial
from .experiment_settings import get_default_experiment_id
@@ -228,78 +228,74 @@ def _load_from_netcdf(
# in the code below floats and ints loaded from attributes are explicitly cast
# this is due to some older versions of qcodes writing them with a different backend
# reading them back results in a numpy array of one element
import cf_xarray as cfxr
import xarray as xr

loaded_data = xr.load_dataset(path, engine="h5netcdf")
with xr.open_dataset(path, engine="h5netcdf") as loaded_data:

loaded_data = cfxr.coding.decode_compress_to_multi_index(loaded_data)
parent_dataset_links = str_to_links(
loaded_data.attrs.get("parent_dataset_links", "[]")
)
if path_to_db is not None:
path_to_db = str(path_to_db)

parent_dataset_links = str_to_links(
loaded_data.attrs.get("parent_dataset_links", "[]")
)
if path_to_db is not None:
path_to_db = str(path_to_db)
with contextlib.closing(
conn_from_dbpath_or_conn(conn=None, path_to_db=path_to_db)
) as conn:
run_data = get_raw_run_attributes(conn, guid=loaded_data.guid)
path_to_db = conn.path_to_dbfile

with contextlib.closing(
conn_from_dbpath_or_conn(conn=None, path_to_db=path_to_db)
) as conn:
run_data = get_raw_run_attributes(conn, guid=loaded_data.guid)
path_to_db = conn.path_to_dbfile
if run_data is not None:
run_id = run_data["run_id"]
counter = run_data["counter"]
else:
run_id = int(loaded_data.captured_run_id)
counter = int(loaded_data.captured_counter)

if run_data is not None:
run_id = run_data["run_id"]
counter = run_data["counter"]
else:
run_id = int(loaded_data.captured_run_id)
counter = int(loaded_data.captured_counter)

path = str(path)
path = os.path.abspath(path)

export_info = ExportInfo.from_str(loaded_data.attrs.get("export_info", ""))
export_info.export_paths["nc"] = path
non_metadata = {
"run_timestamp_raw",
"completed_timestamp_raw",
"ds_name",
"exp_name",
"sample_name",
"export_info",
"parent_dataset_links",
}

metadata_keys = (
set(loaded_data.attrs.keys()) - set(RUNS_TABLE_COLUMNS) - non_metadata
)
metadata = {}
for key in metadata_keys:
data = loaded_data.attrs[key]
if isinstance(data, np.ndarray) and data.size == 1:
data = data[0]
metadata[str(key)] = data
path = str(path)
path = os.path.abspath(path)

ds = cls(
run_id=run_id,
captured_run_id=int(loaded_data.captured_run_id),
counter=counter,
captured_counter=int(loaded_data.captured_counter),
name=loaded_data.ds_name,
exp_id=0,
exp_name=loaded_data.exp_name,
sample_name=loaded_data.sample_name,
guid=loaded_data.guid,
path_to_db=path_to_db,
run_timestamp_raw=float(loaded_data.run_timestamp_raw),
completed_timestamp_raw=float(loaded_data.completed_timestamp_raw),
metadata=metadata,
rundescriber=serial.from_json_to_current(loaded_data.run_description),
parent_dataset_links=parent_dataset_links,
export_info=export_info,
snapshot=loaded_data.snapshot,
)
ds._cache = DataSetCacheInMem(ds)
ds._cache._data = cls._from_xarray_dataset_to_qcodes_raw_data(loaded_data)
export_info = ExportInfo.from_str(loaded_data.attrs.get("export_info", ""))
export_info.export_paths["nc"] = path
non_metadata = {
"run_timestamp_raw",
"completed_timestamp_raw",
"ds_name",
"exp_name",
"sample_name",
"export_info",
"parent_dataset_links",
}

metadata_keys = (
set(loaded_data.attrs.keys()) - set(RUNS_TABLE_COLUMNS) - non_metadata
)
metadata = {}
for key in metadata_keys:
data = loaded_data.attrs[key]
if isinstance(data, np.ndarray) and data.size == 1:
data = data[0]
metadata[str(key)] = data

ds = cls(
run_id=run_id,
captured_run_id=int(loaded_data.captured_run_id),
counter=counter,
captured_counter=int(loaded_data.captured_counter),
name=loaded_data.ds_name,
exp_id=0,
exp_name=loaded_data.exp_name,
sample_name=loaded_data.sample_name,
guid=loaded_data.guid,
path_to_db=path_to_db,
run_timestamp_raw=float(loaded_data.run_timestamp_raw),
completed_timestamp_raw=float(loaded_data.completed_timestamp_raw),
metadata=metadata,
rundescriber=serial.from_json_to_current(loaded_data.run_description),
parent_dataset_links=parent_dataset_links,
export_info=export_info,
snapshot=loaded_data.snapshot,
)
ds._cache = DataSetCacheDeferred(ds, path)

return ds

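
A note on the switch from xr.load_dataset to xr.open_dataset visible above: load_dataset reads the whole file eagerly and closes it, while open_dataset opens the file lazily and keeps the handle open, which is why the new code wraps it in a context manager. A minimal sketch of the pattern (the file name is hypothetical):

import xarray as xr

# The with-block guarantees the file handle is closed again once the
# metadata and any required data have been read.
with xr.open_dataset("export.nc", engine="h5netcdf") as loaded_data:
    guid = loaded_data.guid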
50 changes: 48 additions & 2 deletions tests/dataset/test_dataset_export.py
@@ -9,6 +9,8 @@
import pandas as pd
import pytest
import xarray as xr
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as hst
from numpy.testing import assert_allclose
from pytest import LogCaptureFixture, TempPathFactory

@@ -1187,8 +1189,7 @@ def test_inverted_coords_perserved_on_netcdf_roundtrip(
file_path = os.path.join(path, expected_path)
ds = load_from_netcdf(file_path)

with pytest.warns(UserWarning):
xr_ds_reimported = ds.to_xarray_dataset()
xr_ds_reimported = ds.to_xarray_dataset()

assert xr_ds_reimported["z1"].dims == ("x", "y")
assert xr_ds_reimported["z2"].dims == ("y", "x")
@@ -1362,3 +1363,48 @@ def test_geneate_pandas_index():
pdi = _generate_pandas_index(indexes)
assert isinstance(pdi, pd.MultiIndex)
assert len(pdi) == 3


@given(
function_name=hst.sampled_from(
[
"to_xarray_dataarray_dict",
"to_pandas_dataframe",
"to_pandas_dataframe_dict",
"get_parameter_data",
]
)
)
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,), deadline=None)
def test_export_lazy_load(
tmp_path_factory: TempPathFactory, mock_dataset_grid: DataSet, function_name: str
) -> None:
tmp_path = tmp_path_factory.mktemp("export_netcdf")
path = str(tmp_path)
mock_dataset_grid.export(export_type="netcdf", path=tmp_path, prefix="qcodes_")

xr_ds = mock_dataset_grid.to_xarray_dataset()
assert xr_ds["z"].dims == ("x", "y")

expected_path = (
f"qcodes_{mock_dataset_grid.captured_run_id}_{mock_dataset_grid.guid}.nc"
)
assert os.listdir(path) == [expected_path]
file_path = os.path.join(path, expected_path)
ds = load_from_netcdf(file_path)

# loading the dataset should not load the actual data into cache
assert ds.cache._data == {}
# loading directly into xarray should not round
# trip to qcodes format and therefore not fill the cache
xr_ds_reimported = ds.to_xarray_dataset()
assert ds.cache._data == {}

assert xr_ds_reimported["z"].dims == ("x", "y")
assert xr_ds.identical(xr_ds_reimported)

# but loading with any of these functions
# will currently fill the cache
getattr(ds, function_name)()

assert ds.cache._data != {}