save compressed netcdf #21

Merged (53 commits) on Apr 7, 2022
Changes from 32 commits

Commits
4d20a90  initial (Apr 5, 2022)
8e713b8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 5, 2022)
932f453  Delete setup_levante.md (aaronspring, Apr 5, 2022)
8f78ed1  get_encoding (Apr 6, 2022)
c8fe94b  to_compressed_netcdf (Apr 6, 2022)
59da477  Merge branch 'save_compressed_netcdf' of https://github.com/observing… (Apr 6, 2022)
1f812e5  Delete save_compressed.py (aaronspring, Apr 6, 2022)
1444bde  Delete setup_levante.md (aaronspring, Apr 6, 2022)
f18f4f3  Merge branch 'main' into save_compressed_netcdf (aaronspring, Apr 6, 2022)
9561c93  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 6, 2022)
c14f2e7  Update __init__.py (aaronspring, Apr 6, 2022)
7ffedfb  add numcodecs to setup.py (Apr 6, 2022)
951fbef  Merge branch 'save_compressed_netcdf' of https://github.com/observing… (Apr 6, 2022)
5ef9a11  fix init (Apr 6, 2022)
84e6fc9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 6, 2022)
bc18481  Merge branch 'main' into save_compressed_netcdf (aaronspring, Apr 6, 2022)
f80c6a6  fix (Apr 6, 2022)
83068fe  Merge branch 'save_compressed_netcdf' of https://github.com/observing… (Apr 6, 2022)
d04b814  Merge branch 'main' into save_compressed_netcdf (aaronspring, Apr 6, 2022)
7aab6d3  fix _time_dim (Apr 6, 2022)
8f2b880  Merge branch 'save_compressed_netcdf' of https://github.com/observing… (Apr 6, 2022)
3203681  fix path.getsize (Apr 6, 2022)
0733266  add dask to setup.py (Apr 6, 2022)
ad6b82c  requirements.txt (Apr 6, 2022)
99791e7  test bitround test get_compress_encoding (Apr 6, 2022)
2993ee5  doctests conftest (Apr 6, 2022)
f418235  fix tests (Apr 6, 2022)
4d0745c  add doctests (Apr 6, 2022)
4ec2215  fix tests (Apr 6, 2022)
8450be6  fix tests (Apr 6, 2022)
304a3a2  fix tests (Apr 6, 2022)
552e989  update notebook (aaronspring, Apr 6, 2022)
91d8ada  add compression as keyword (Apr 6, 2022)
12f1957  Merge branch 'save_compressed_netcdf' of https://github.com/observing… (Apr 6, 2022)
85809f5  keep keepbits-9 in xr_bitround (Apr 6, 2022)
351ba9e  fix bitround keep (Apr 6, 2022)
639216d  fix bitround keep (Apr 6, 2022)
9aec307  fix test? (Apr 6, 2022)
41426d0  fix test (Apr 6, 2022)
383f578  Apply suggestions from code review (aaronspring, Apr 6, 2022)
fa9c557  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 6, 2022)
4dcb4da  refactor xr_bitround (Apr 7, 2022)
bb879a4  refactor xr_bitround and merge (Apr 7, 2022)
a6aa9a4  fix (Apr 7, 2022)
e0bccc1  fix (Apr 7, 2022)
d9ca167  fix (Apr 7, 2022)
c64db0c  fix (Apr 7, 2022)
988741e  keep rasm float64 and xr_bitround only for float32 (Apr 7, 2022)
59df258  test_xr_bitround on ds and da (Apr 7, 2022)
ccf4c1c  fix tests (Apr 7, 2022)
ed161bd  fix tests (Apr 7, 2022)
799f3e4  fix tests (Apr 7, 2022)
fd9bebe  fix tests (Apr 7, 2022)
30 changes: 30 additions & 0 deletions .github/workflows/ci.yaml
@@ -57,6 +57,35 @@ jobs:
- name: Run tests
run: pytest

doctest: # tests all docstrings
name: doctests
runs-on: "ubuntu-latest"
defaults:
run:
shell: bash -l {0}
steps:
- uses: actions/checkout@v3
- uses: conda-incubator/setup-miniconda@v2
with:
channels: conda-forge
mamba-version: "*"
activate-environment: bitinfo
auto-update-conda: false
python-version: 3.8
- name: Install conda dependencies
run: |
mamba env update -f environment.yml
- name: Install bitinformation_pipeline
run: |
python -m pip install -e .
- name: Conda info
run: conda info
- name: Conda list
run: conda list
- name: Run doctests
run: |
python -m pytest --doctest-modules bitinformation_pipeline --ignore bitinformation_pipeline/tests

test_notebooks: # Checks that notebooks complete without failure
name: "notebooks"
runs-on: ubuntu-latest
@@ -109,5 +138,6 @@ jobs:
python-version: 3.8
- name: Install dependencies
run: |
pip install git+https://github.com/observingClouds/numcodecs@bitround # remove once merged
pip install .
python -c "import bitinformation_pipeline"
6 changes: 2 additions & 4 deletions bitinformation_pipeline/__init__.py
@@ -1,11 +1,9 @@
"""Top-level package for bitinformation_pipeline."""

__author__ = """Hauke Schulz"""
__email__ = "hauke.schulz@mpimet.mpg.de"
__version__ = "0.0.1"

from .bitinformation_pipeline import (
get_bitinformation,
get_keepbits,
plot_bitinformation,
)
from .bitround import xr_bitround
from .save_compressed import get_compress_encoding
35 changes: 34 additions & 1 deletion bitinformation_pipeline/bitinformation_pipeline.py
@@ -53,6 +53,28 @@ def get_bitinformation(ds, label=None, overwrite=False, **kwargs):
-------
info_per_bit : dict
Information content per bit and variable

Example
-------
>>> ds = xr.tutorial.load_dataset("rasm")
>>> bp.get_bitinformation(ds, dim="x")
{'Tair': array([6.28759085e-01, 7.37993809e-01, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
4.10064704e-06, 4.75985410e-01, 5.20519662e-01, 3.93991763e-01,
3.63939048e-01, 2.00036924e-01, 1.31092761e-01, 8.93712601e-02,
7.16473010e-02, 6.84045398e-02, 6.87513712e-02, 6.89925122e-02,
6.90843796e-02, 6.90237233e-02, 6.95050692e-02, 6.94720711e-02,
6.94601453e-02, 6.96058765e-02, 6.98843472e-02, 6.92556384e-02,
6.84707129e-02, 6.91220148e-02, 6.93995066e-02, 6.92542336e-02,
6.88630993e-02, 6.88312736e-02, 6.89656830e-02, 6.93957020e-02,
6.85981736e-02, 6.97206990e-02, 6.96303301e-02, 6.89981939e-02,
7.03003113e-02, 6.96626582e-02, 6.94376911e-02, 6.91778910e-02,
6.93997653e-02, 7.01042669e-02, 6.96544993e-02, 6.92199298e-02,
6.97360327e-02, 6.95376714e-02, 6.97447985e-02, 6.95418140e-02,
6.96346655e-02, 6.97496057e-02, 6.95058114e-02, 6.93239423e-02,
6.89041586e-02, 6.95802295e-02, 6.96424276e-02, 6.56236800e-02,
6.95009315e-02, 7.67737099e-02, 8.06336563e-02, 8.08729657e-02])}

"""
if label is not None and overwrite is False:
info_per_bit = load_bitinformation(label)
@@ -100,7 +122,8 @@ def load_bitinformation(label):


def get_keepbits(ds, info_per_bit, inflevel=0.99):
"""Get the amount of bits to keep for a given information content
"""Get the amount of bits to keep for a given information content.

Inputs
------
ds : xr.Dataset
@@ -110,10 +133,20 @@ def get_keepbits(ds, info_per_bit, inflevel=0.99):
inflevel : float or dict
Level of information that shall be preserved. Of type `float` if the
preserved information content should be equal across variables, otherwise of type `dict`.

Returns
-------
keepbits : dict
Number of bits to keep per variable

Example
-------
>>> ds = xr.tutorial.load_dataset("rasm")
>>> info_per_bit = bp.get_bitinformation(ds, dim="x")
>>> bp.get_keepbits(ds, info_per_bit)
{'Tair': 15}
>>> bp.get_keepbits(ds, info_per_bit, inflevel=0.99999999)
{'Tair': 15}
"""

def get_inflevel(var, inflevel):
62 changes: 62 additions & 0 deletions bitinformation_pipeline/bitround.py
@@ -0,0 +1,62 @@
import xarray as xr
from numcodecs.bitround import BitRound


def bitround(data, keepbits):
"""Bitround for Arrays."""
codec = BitRound(keepbits=keepbits)
data = data.copy() # otherwise overwrites the input
encoded = codec.encode(data)
return codec.decode(encoded)


def xr_bitround(ds, keepbits):
"""Apply bitrounding based on keepbits from bp.get_keepbits for xarray.Dataset or xr.DataArray.

Inputs
------
    ds : xr.Dataset or xr.DataArray
        input data to bitround
keepbits : int or dict
how many mantissa bits to keep

Returns
-------
    ds_bitrounded : xr.Dataset or xr.DataArray

Example
-------
>>> ds = xr.tutorial.load_dataset("rasm")
>>> info_per_bit = bp.get_bitinformation(ds, dim="x")
>>> keepbits = bp.get_keepbits(ds, info_per_bit, 0.99)
>>> ds_bitrounded = xr_bitround(ds, keepbits)
"""
ds_bitrounded = ds.copy()
if isinstance(ds, xr.Dataset):
for v in ds.data_vars:
if (
ds[v].dtype == "float64"
): # fails otherwise see https://github.com/zarr-developers/numcodecs/blob/7c7dc7cc83db1ae5c9fd93ece863acedbbc8156f/numcodecs/bitround.py#L23
ds[v] = ds[v].astype("float32")
aaronspring (Collaborator, Author) commented:
Do you have an opinion here, @observingClouds?

aaronspring (Collaborator, Author) replied on Apr 7, 2022:
That's a disadvantage of the Python version: it seems to be designed primarily for float32, which is the dtype we deal with most often, IMO.
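
For context, a minimal sketch of the restriction being worked around (assuming the numcodecs bitround branch referenced in this PR; the array values are illustrative):

    import numpy as np
    from numcodecs.bitround import BitRound

    codec = BitRound(keepbits=7)
    arr64 = np.linspace(0, 1, 10)  # float64 input
    # codec.encode(arr64) would fail: BitRound only accepts float32 buffers,
    # hence the astype("float32") cast above
    rounded = codec.decode(codec.encode(arr64.astype("float32")))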

if isinstance(keepbits, int):
keep = keepbits
elif isinstance(keepbits, dict):
if v in keepbits.keys():
keep = keepbits[v]
else:
continue
            # assign via .values; assigning to .data fails here
ds_bitrounded[v].values = bitround(ds[v].values, keep)
ds_bitrounded[v].attrs["bitround_keepbits"] = keep
aaronspring (Collaborator, Author) commented:
We could also add an attribute to the whole dataset linking to bitinformation.jl and bitinformation_pipeline.

observingClouds (Owner) commented:
In Zarr this information is written as a filters attribute, e.g.

"filters": [
        {
            "id": "bitround",
            "keepbits": 2
        }
    ],

This has, however, the shortcoming that the codec is requested when accessing the data; although the bitround codec is straightforward, it just increases the burden for others to read the data if they don't have the filter installed.
The CF Conventions propose:
[image: table of proposed CF quantization attributes]
but that does not really help for a variable-dependent filter...
I didn't completely follow the discussion there, so I don't know whether they implemented the same technique we are using here. I get confused by BitRounding, BitGrooming, etc. @milankl can maybe clarify this.
But if they use the same technique, we should probably use _QuantizeBitRoundNumberOfSignificantDigits as the attribute reference.
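
For illustration, a minimal sketch of how such a filter entry ends up in the Zarr metadata (assuming zarr plus the numcodecs bitround branch referenced in this PR; the store name is illustrative):

    import numpy as np
    import zarr
    from numcodecs.bitround import BitRound

    # the filter is applied on write and recorded in the array metadata
    z = zarr.open("example.zarr", mode="w", shape=(100,), chunks=(100,),
                  dtype="f4", filters=[BitRound(keepbits=2)])
    z[:] = np.random.rand(100).astype("f4")
    # example.zarr/.zarray now contains the "filters" entry shown above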

aaronspring (Collaborator, Author) commented:
Does _QuantizeBitRoundNumberOfSignificantDigits refer to bits or mantissa bits? I think nco uses bits but numcodecs uses mantissa bits.

aaronspring (Collaborator, Author) commented:
I now use bits, not mantissa bits.
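
To make the convention concrete, a hypothetical sketch of what keeping N of the 23 float32 mantissa bits means (round-half-up for brevity; numcodecs' BitRound uses round-to-nearest with ties to even):

    import numpy as np

    def manual_bitround(a, keepbits):
        """Keep `keepbits` mantissa bits of a float32 array, zero the rest."""
        assert a.dtype == np.float32 and 0 < keepbits < 23
        bits = a.view(np.uint32).copy()
        drop = 23 - keepbits                             # mantissa bits to discard
        bits += np.uint32(1 << (drop - 1))               # round half up (finite values assumed)
        bits &= np.uint32((0xFFFFFFFF >> drop) << drop)  # zero the dropped bits
        return bits.view(np.float32)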

elif isinstance(ds, xr.DataArray):
if isinstance(keepbits, int):
keep = keepbits
elif isinstance(keepbits, dict):
v = ds.name
if v in keepbits.keys():
keep = keepbits[v]
else:
                raise ValueError("name not found in keepbits:", keepbits.keys())
ds_bitrounded.data = bitround(ds.data, keep)
ds_bitrounded.attrs["bitround_keepbits"] = keep
return ds_bitrounded
92 changes: 92 additions & 0 deletions bitinformation_pipeline/save_compressed.py
@@ -0,0 +1,92 @@
import xarray as xr
from dask import is_dask_collection


def get_chunksizes(da, for_cdo=False, time_dim="time", chunks=None):
"""Get chunksizes for xr.DataArray for to_netcdf(encoding) from original file.
If for_cdo, ensure time chunksize of 1 when compressed."""
assert isinstance(da, xr.DataArray)
if chunks: # use new chunksizes
return da.chunk(chunks).data.chunksize
if for_cdo: # take shape as chunksize and ensure time chunksize 1
time_axis_num = da.get_axis_num(time_dim)
chunksize = da.data.chunksize if is_dask_collection(da) else da.shape
# https://code.mpimet.mpg.de/boards/2/topics/12598
chunksize = list(chunksize)
chunksize[time_axis_num] = 1
chunksize = tuple(chunksize)
return chunksize
else:
if is_dask_collection(da):
return da.data.chunksize
else:
return da.shape
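
As a quick illustration with the rasm tutorial dataset (the values match the get_compress_encoding doctest below):

    import xarray as xr

    ds = xr.tutorial.load_dataset("rasm")
    get_chunksizes(ds["Tair"])                # (36, 205, 275): full shape
    get_chunksizes(ds["Tair"], for_cdo=True)  # (1, 205, 275): time chunksize 1 for CDO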


def get_compress_encoding(
ds_bitrounded,
shuffle=True,
complevel=9,
for_cdo=False,
time_dim="time",
chunks=None,
):
"""Generate encoding for ds_bitrounded.to_netcdf(encoding).

Example:
>>> ds = xr.tutorial.load_dataset("rasm")
>>> get_compress_encoding(ds)
{'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (36, 205, 275)}}
>>> get_compress_encoding(ds, for_cdo=True)
{'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (1, 205, 275)}}

"""
return {
v: {
"zlib": True,
"shuffle": shuffle,
"complevel": complevel,
"chunksizes": get_chunksizes(
ds_bitrounded[v], for_cdo=for_cdo, time_dim=time_dim, chunks=chunks
),
}
for v in ds_bitrounded.data_vars
}


@xr.register_dataset_accessor("to_compressed_netcdf")
class ToCompressed_Netcdf:
"""Save to compressed netcdf wrapper.

Example:
>>> ds = xr.tutorial.load_dataset("rasm")
>>> path = "rasm.nc"
>>> ds.to_compressed_netcdf(path)
>>> ds.to_compressed_netcdf(path, complevel=4)
>>> ds.to_compressed_netcdf(path, for_cdo=True)

"""

def __init__(self, xarray_obj):
self._obj = xarray_obj

def __call__(
self,
path,
shuffle=True,
complevel=9,
for_cdo=False,
time_dim="time",
chunks=None,
):
self._obj.to_netcdf(
path,
encoding=get_compress_encoding(
self._obj,
shuffle=shuffle,
complevel=complevel,
for_cdo=for_cdo,
time_dim=time_dim,
chunks=chunks,
),
)
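
Taken together, the pieces added in this PR compose into a short pipeline; a sketch assembled from the doctest examples above (the output path is illustrative):

    import xarray as xr
    import bitinformation_pipeline as bp

    ds = xr.tutorial.load_dataset("rasm")
    info_per_bit = bp.get_bitinformation(ds, dim="x")      # information content per bit
    keepbits = bp.get_keepbits(ds, info_per_bit, 0.99)     # e.g. {'Tair': 15}
    ds_bitrounded = bp.xr_bitround(ds, keepbits)           # zero low-information bits
    ds_bitrounded.to_compressed_netcdf("rasm_compressed.nc")  # zlib + shuffle encoding
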
46 changes: 46 additions & 0 deletions conftest.py
@@ -0,0 +1,46 @@
import numpy as np
import pytest
import xarray as xr
from xarray.tutorial import load_dataset

import bitinformation_pipeline as bp

xr.set_options(display_style="text")


@pytest.fixture(autouse=True)
def add_standard_imports(
doctest_namespace,
):
"""imports for doctest"""
xr.set_options(display_style="text")
doctest_namespace["np"] = np
doctest_namespace["xr"] = xr
doctest_namespace["bp"] = bp
# always seed numpy.random to make the examples deterministic
np.random.seed(42)


@pytest.fixture()
def rasm():
return load_dataset("rasm")


@pytest.fixture()
def air_temperature():
return load_dataset("air_temperature")


@pytest.fixture()
def ROMS_example():
return load_dataset("ROMS_example")


@pytest.fixture()
def era52mt():
return load_dataset("era5-2mt-2019-03-uk.grib")


@pytest.fixture()
def eraint_uvz():
return load_dataset("eraint_uvz")