Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save compressed netcdf #21

Merged
merged 53 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
4d20a90
initial
Apr 5, 2022
8e713b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2022
932f453
Delete setup_levante.md
aaronspring Apr 5, 2022
8f78ed1
get_encoding
Apr 6, 2022
c8fe94b
to_compressed_netcdf
Apr 6, 2022
59da477
Merge branch 'save_compressed_netcdf' of https://github.com/observing…
Apr 6, 2022
1f812e5
Delete save_compressed.py
aaronspring Apr 6, 2022
1444bde
Delete setup_levante.md
aaronspring Apr 6, 2022
f18f4f3
Merge branch 'main' into save_compressed_netcdf
aaronspring Apr 6, 2022
9561c93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2022
c14f2e7
Update __init__.py
aaronspring Apr 6, 2022
7ffedfb
add numcodecs to setup.py
Apr 6, 2022
951fbef
Merge branch 'save_compressed_netcdf' of https://github.com/observing…
Apr 6, 2022
5ef9a11
fix init
Apr 6, 2022
84e6fc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2022
bc18481
Merge branch 'main' into save_compressed_netcdf
aaronspring Apr 6, 2022
f80c6a6
fix
Apr 6, 2022
83068fe
Merge branch 'save_compressed_netcdf' of https://github.com/observing…
Apr 6, 2022
d04b814
Merge branch 'main' into save_compressed_netcdf
aaronspring Apr 6, 2022
7aab6d3
fix _time_dim
Apr 6, 2022
8f2b880
Merge branch 'save_compressed_netcdf' of https://github.com/observing…
Apr 6, 2022
3203681
fix path.getsize
Apr 6, 2022
0733266
add dask to setup.py
Apr 6, 2022
ad6b82c
requirements.txt
Apr 6, 2022
99791e7
test bitround test get_compress_encoding
Apr 6, 2022
2993ee5
doctests conftest
Apr 6, 2022
f418235
fix tests
Apr 6, 2022
4d0745c
add doctests
Apr 6, 2022
4ec2215
fix tests
Apr 6, 2022
8450be6
fix tests
Apr 6, 2022
304a3a2
fix tests
Apr 6, 2022
552e989
update notebook
aaronspring Apr 6, 2022
91d8ada
add compression as keyword
Apr 6, 2022
12f1957
Merge branch 'save_compressed_netcdf' of https://github.com/observing…
Apr 6, 2022
85809f5
keep keepbits-9 in xr_bitround
Apr 6, 2022
351ba9e
fix bitround keep
Apr 6, 2022
639216d
fix bitround keep
Apr 6, 2022
9aec307
fix test?
Apr 6, 2022
41426d0
fix test
Apr 6, 2022
383f578
Apply suggestions from code review
aaronspring Apr 6, 2022
fa9c557
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2022
4dcb4da
refactor xr_bitround
Apr 7, 2022
bb879a4
refactor xr_bitround and merge
Apr 7, 2022
a6aa9a4
fix
Apr 7, 2022
e0bccc1
fix
Apr 7, 2022
d9ca167
fix
Apr 7, 2022
c64db0c
fix
Apr 7, 2022
988741e
keep rasm float64 and xr_bitround only for float32
Apr 7, 2022
59df258
test_xr_bitround on ds and da
Apr 7, 2022
ccf4c1c
fix tests
Apr 7, 2022
ed161bd
fix tests
Apr 7, 2022
799f3e4
fix tests
Apr 7, 2022
fd9bebe
fix tests
Apr 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,35 @@ jobs:
- name: Run tests
run: pytest

doctest: # tests all docstrings
name: doctests
runs-on: "ubuntu-latest"
defaults:
run:
shell: bash -l {0}
steps:
- uses: actions/checkout@v3
- uses: conda-incubator/setup-miniconda@v2
with:
channels: conda-forge
mamba-version: "*"
activate-environment: bitinfo
auto-update-conda: false
python-version: 3.8
- name: Install conda dependencies
run: |
mamba env update -f environment.yml
- name: Install bitinformation_pipeline
run: |
python -m pip install -e .
- name: Conda info
run: conda info
- name: Conda list
run: conda list
- name: Run doctests
run: |
python -m pytest --doctest-modules bitinformation_pipeline --ignore bitinformation_pipeline/tests

test_notebooks: # Checks that notebooks complete without failure
name: "notebooks"
runs-on: ubuntu-latest
Expand Down Expand Up @@ -109,5 +138,6 @@ jobs:
python-version: 3.8
- name: Install dependencies
run: |
pip install git+https://github.com/observingClouds/numcodecs@bitround # remove once merged
pip install .
python -c "import bitinformation_pipeline"
6 changes: 2 additions & 4 deletions bitinformation_pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Top-level package for bitinformation_pipeline."""

__author__ = """Hauke Schulz"""
__email__ = "hauke.schulz@mpimet.mpg.de"
__version__ = "0.0.1"

from .bitinformation_pipeline import (
get_bitinformation,
get_keepbits,
plot_bitinformation,
)
from .bitround import xr_bitround
from .save_compressed import get_compress_encoding
26 changes: 25 additions & 1 deletion bitinformation_pipeline/bitinformation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def get_bitinformation(ds, label=None, overwrite=False, **kwargs):
-------
info_per_bit : dict
Information content per bit and variable

Example
-------
>>> ds = xr.tutorial.load_dataset("air_temperature")
>>> bp.get_bitinformation(ds, dim="lon")
{'air': array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 3.94447851e-01, 3.94447851e-01, 3.94447851e-01,
3.94447851e-01, 3.94447851e-01, 3.94310542e-01, 7.36739987e-01,
5.62682836e-01, 3.60511555e-01, 1.52471111e-01, 4.18818055e-02,
3.65276146e-03, 1.19975820e-05, 4.39366160e-05, 4.18329296e-05,
2.54572089e-05, 1.44121797e-04, 1.34144798e-03, 1.55468479e-06,
5.38601212e-04, 8.09862581e-04, 1.74893445e-04, 4.97915410e-05,
3.88027711e-04, 0.00000000e+00, 3.95323228e-05, 6.88854435e-04])}
"""
if label is not None and overwrite is False:
info_per_bit = load_bitinformation(label)
Expand Down Expand Up @@ -100,7 +113,8 @@ def load_bitinformation(label):


def get_keepbits(ds, info_per_bit, inflevel=0.99):
"""Get the amount of bits to keep for a given information content
"""Get the amount of bits to keep for a given information content.

Inputs
------
ds : xr.Dataset
Expand All @@ -110,10 +124,20 @@ def get_keepbits(ds, info_per_bit, inflevel=0.99):
inflevel : float or dict
Level of information that shall be preserved. Of type `float` if the
preserved information content should be equal across variables, otherwise of type `dict`.

Returns
-------
keepbits : dict
Number of bits to keep per variable

Example
-------
>>> ds = xr.tutorial.load_dataset("air_temperature")
>>> info_per_bit = bp.get_bitinformation(ds, dim="lon")
>>> bp.get_keepbits(ds, info_per_bit)
{'air': 16}
>>> bp.get_keepbits(ds, info_per_bit, inflevel=0.99999999)
{'air': 23}
"""

def get_inflevel(var, inflevel):
Expand Down
57 changes: 57 additions & 0 deletions bitinformation_pipeline/bitround.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import xarray as xr
from numcodecs.bitround import BitRound


def bitround(data, keepbits):
    """Round an array's mantissa, keeping ``keepbits`` bits of information.

    Runs the data through :class:`numcodecs.bitround.BitRound` (encode then
    decode), which zeroes the discarded trailing mantissa bits.
    """
    rounder = BitRound(keepbits=keepbits)
    # encode works in place, so operate on a copy to leave the input intact
    buffer = data.copy()
    return rounder.decode(rounder.encode(buffer))


def xr_bitround(da, keepbits):
    """Apply bitrounding based on keepbits from bp.get_keepbits for xarray.Dataset or xr.DataArray.

    Inputs
    ------
    da : xr.DataArray or xr.Dataset
        input netcdf to bitround with dtype float32
    keepbits : int or dict of {str: int}
        how many bits to keep. If ``int``, applied to all variables; if
        ``dict``, looked up per variable name.

    Returns
    -------
    da_bitrounded : xr.DataArray or xr.Dataset
        copy of ``da`` with rounded mantissa bits; the number of kept bits is
        documented in the ``_QuantizeBitRoundNumberOfSignificantDigits`` attribute

    Raises
    ------
    ValueError
        if ``keepbits`` is a dict and the variable name is missing from it
    TypeError
        if ``keepbits`` is neither an int nor a dict

    Example
    -------
    >>> ds = xr.tutorial.load_dataset("air_temperature")
    >>> info_per_bit = bp.get_bitinformation(ds, dim="lon")
    >>> keepbits = bp.get_keepbits(ds, info_per_bit, 0.99)
    >>> ds_bitrounded = xr_bitround(ds, keepbits)
    """
    if isinstance(da, xr.Dataset):
        # recurse over each data variable and reassemble the Dataset
        da_bitrounded = da.copy()
        for v in da.data_vars:
            da_bitrounded[v] = xr_bitround(da[v], keepbits)
        return da_bitrounded

    assert da.dtype == "float32"
    da_bitrounded = da.copy()
    if isinstance(keepbits, int):
        keep = keepbits
    elif isinstance(keepbits, dict):
        v = da.name
        if v in keepbits:
            keep = keepbits[v]
        else:
            raise ValueError(f"name {v} not found in keepbits: {keepbits.keys()}")
    else:
        # previously fell through silently and crashed with NameError on `keep`
        raise TypeError(f"keepbits must be int or dict, found {type(keepbits)}")
    # bitround fails for .data, therefore use .values
    # keep - 9: presumably offsets the 9 non-mantissa bits of float32
    # (1 sign + 8 exponent) so BitRound receives mantissa bits in [0-23]
    # -- TODO confirm offset convention
    da_bitrounded.values = bitround(da.values, keep - 9)
    # document keepbits for downstream tooling
    da_bitrounded.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keep
    return da_bitrounded
95 changes: 95 additions & 0 deletions bitinformation_pipeline/save_compressed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import xarray as xr
from dask import is_dask_collection


def get_chunksizes(da, for_cdo=False, time_dim="time", chunks=None):
    """Get chunksizes for xr.DataArray for to_netcdf(encoding) from original file.

    If ``for_cdo``, ensure a time chunksize of 1 when compressed.
    """
    assert isinstance(da, xr.DataArray)
    if chunks:  # rechunk and report the resulting dask chunksizes
        return da.chunk(chunks).data.chunksize
    # current layout: dask chunksizes if dask-backed, otherwise the full shape
    current = da.data.chunksize if is_dask_collection(da) else da.shape
    if not for_cdo:
        return current
    # cdo needs a time chunksize of 1 for compressed files
    # https://code.mpimet.mpg.de/boards/2/topics/12598
    sizes = list(current)
    sizes[da.get_axis_num(time_dim)] = 1
    return tuple(sizes)


def get_compress_encoding(
    ds_bitrounded,
    compression="zlib",
    shuffle=True,
    complevel=9,
    for_cdo=False,
    time_dim="time",
    chunks=None,
):
    """Generate encoding for ds_bitrounded.to_netcdf(encoding).

    Example:
    >>> ds = xr.tutorial.load_dataset("rasm")
    >>> get_compress_encoding(ds)
    {'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (36, 205, 275)}}
    >>> get_compress_encoding(ds, for_cdo=True)
    {'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (1, 205, 275)}}

    """
    # one encoding entry per data variable, keyed by variable name
    encoding = {}
    for name in ds_bitrounded.data_vars:
        encoding[name] = {
            compression: True,
            "shuffle": shuffle,
            "complevel": complevel,
            "chunksizes": get_chunksizes(
                ds_bitrounded[name], for_cdo=for_cdo, time_dim=time_dim, chunks=chunks
            ),
        }
    return encoding


@xr.register_dataset_accessor("to_compressed_netcdf")
class ToCompressed_Netcdf:
    """Save to compressed netcdf wrapper.

    Example:
    >>> ds = xr.tutorial.load_dataset("rasm")
    >>> path = "rasm.nc"
    >>> ds.to_compressed_netcdf(path)
    >>> ds.to_compressed_netcdf(path, complevel=4)
    >>> ds.to_compressed_netcdf(path, for_cdo=True)

    """

    def __init__(self, xarray_obj):
        # keep a reference to the wrapped Dataset
        self._obj = xarray_obj

    def __call__(
        self,
        path,
        compression="zlib",
        shuffle=True,
        complevel=9,
        for_cdo=False,
        time_dim="time",
        chunks=None,
    ):
        # build the per-variable encoding, then delegate to xarray's to_netcdf
        encoding = get_compress_encoding(
            self._obj,
            compression=compression,
            shuffle=shuffle,
            complevel=complevel,
            for_cdo=for_cdo,
            time_dim=time_dim,
            chunks=chunks,
        )
        self._obj.to_netcdf(path, encoding=encoding)
51 changes: 51 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import pytest
import xarray as xr
from xarray.tutorial import load_dataset

import bitinformation_pipeline as bp

xr.set_options(display_style="text")


@pytest.fixture(autouse=True)
def add_standard_imports(
    doctest_namespace,
):
    """Inject the standard aliases into every doctest namespace."""
    xr.set_options(display_style="text")
    for alias, module in (("np", np), ("xr", xr), ("bp", bp)):
        doctest_namespace[alias] = module
    # always seed numpy.random to make the examples deterministic
    np.random.seed(42)


@pytest.fixture()
def rasm():
    """RASM tutorial dataset: one atmospheric variable, float64, with masked ocean."""
    return load_dataset("rasm")


@pytest.fixture()
def air_temperature():
    """Air-temperature tutorial dataset: one atmospheric variable, float32, over the US, no mask."""
    return load_dataset("air_temperature")


@pytest.fixture()
def ROMS_example():
    """ROMS tutorial dataset: two ocean variables, float32, with masked land."""
    return load_dataset("ROMS_example")


@pytest.fixture()
def era52mt():
    """ERA5 tutorial dataset: one variable t2m, float32, over the UK, no mask."""
    return load_dataset("era5-2mt-2019-03-uk.grib")


@pytest.fixture()
def eraint_uvz():
    """ERA-Interim tutorial dataset: three atmospheric variables, float32, global, no mask."""
    return load_dataset("eraint_uvz")
Loading