save compressed netcdf #21
@@ -1,11 +1,9 @@
"""Top-level package for bitinformation_pipeline.""" | ||
|
||
__author__ = """Hauke Schulz""" | ||
__email__ = "hauke.schulz@mpimet.mpg.de" | ||
__version__ = "0.0.1" | ||
|
||
from .bitinformation_pipeline import ( | ||
get_bitinformation, | ||
get_keepbits, | ||
plot_bitinformation, | ||
) | ||
from .bitround import xr_bitround | ||
from .save_compressed import get_compress_encoding |
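Taken together, the names re-exported above form the intended workflow. A minimal sketch, assuming the package and its dependencies are installed and the xarray tutorial data can be downloaded (it mirrors the doctests in the files below):

import xarray as xr
import bitinformation_pipeline as bp

ds = xr.tutorial.load_dataset("rasm")
info_per_bit = bp.get_bitinformation(ds, dim="x")   # real information content per bit
keepbits = bp.get_keepbits(ds, info_per_bit, 0.99)  # bits needed to keep 99% of the information
ds_bitrounded = bp.xr_bitround(ds, keepbits)        # round away the remaining mantissa bits
ds_bitrounded.to_compressed_netcdf("rasm_compressed.nc")  # accessor registered via save_compressed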
@@ -0,0 +1,62 @@
import xarray as xr
from numcodecs.bitround import BitRound


def bitround(data, keepbits):
    """Bitround for Arrays."""
    codec = BitRound(keepbits=keepbits)
    data = data.copy()  # otherwise overwrites the input
    encoded = codec.encode(data)
    return codec.decode(encoded)


def xr_bitround(ds, keepbits):
    """Apply bitrounding based on keepbits from bp.get_keepbits for xarray.Dataset or xr.DataArray.

    Inputs
    ------
    ds : xr.Dataset
        input netcdf to bitround
    keepbits : int or dict
        how many mantissa bits to keep

    Returns
    -------
    ds_bitrounded : xr.Dataset

    Example
    -------
    >>> ds = xr.tutorial.load_dataset("rasm")
    >>> info_per_bit = bp.get_bitinformation(ds, dim="x")
    >>> keepbits = bp.get_keepbits(ds, info_per_bit, 0.99)
    >>> ds_bitrounded = xr_bitround(ds, keepbits)
    """
    ds_bitrounded = ds.copy()
    if isinstance(ds, xr.Dataset):
        for v in ds.data_vars:
            if ds[v].dtype == "float64":
                # numcodecs BitRound fails on float64 input, see
                # https://github.com/zarr-developers/numcodecs/blob/7c7dc7cc83db1ae5c9fd93ece863acedbbc8156f/numcodecs/bitround.py#L23
                ds[v] = ds[v].astype("float32")
            if isinstance(keepbits, int):
                keep = keepbits
            elif isinstance(keepbits, dict):
                if v in keepbits.keys():
                    keep = keepbits[v]
                else:
                    continue

            # assign via .values, assigning to .data fails
            ds_bitrounded[v].values = bitround(ds[v].values, keep)
            ds_bitrounded[v].attrs["bitround_keepbits"] = keep
Review comment: could also add an attribute to the whole dataset linking to bitinformation.jl and bitinformation_pipeline.

Review comment: In Zarr this information is written as a "filters" entry in the metadata:

    "filters": [
        {
            "id": "bitround",
            "keepbits": 2
        }
    ],

This has, however, the shortcoming that a codec is required when accessing the data; the bitround decoder is straightforward, but it increases the burden for others reading the data if they don't have the filter installed.

Review comment: Does _QuantizeBitRoundNumberOfSignificantDigits refer to bits or mantissa bits? I think nco uses bits, but numcodecs uses mantissa bits.

Review comment: I now use bits, not mantissa bits.
    elif isinstance(ds, xr.DataArray):
        if isinstance(keepbits, int):
            keep = keepbits
        elif isinstance(keepbits, dict):
            v = ds.name
            if v in keepbits.keys():
                keep = keepbits[v]
            else:
                raise ValueError("name not found in keepbits:", keepbits.keys())

        ds_bitrounded.data = bitround(ds.data, keep)
        ds_bitrounded.attrs["bitround_keepbits"] = keep
    return ds_bitrounded
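Two hedged sketches related to the code and review thread above. First, the per-variable dict branch of xr_bitround; the variable name and keepbits value are illustrative, and loading the tutorial data needs network access:

import xarray as xr
from bitinformation_pipeline import xr_bitround

ds = xr.tutorial.load_dataset("rasm")
# keep 7 mantissa bits for Tair; variables missing from the dict are skipped (the `continue` above)
ds_bitrounded = xr_bitround(ds, {"Tair": 7})
assert ds_bitrounded["Tair"].attrs["bitround_keepbits"] == 7

Second, how keepbits ends up as the "filters" metadata quoted in the review comment when writing Zarr instead of netCDF; the store name and keepbits are again illustrative, this uses numcodecs' BitRound directly rather than xr_bitround, and it assumes the zarr backend is installed:

from numcodecs.bitround import BitRound

# the codec's config is exactly the metadata block quoted above
print(BitRound(keepbits=2).get_config())  # {'id': 'bitround', 'keepbits': 2}

# passing the codec as a filter lets zarr record it in the variable's metadata
ds32 = xr.tutorial.load_dataset("rasm").astype("float32")
ds32.to_zarr("rasm_bitround.zarr", mode="w", encoding={"Tair": {"filters": [BitRound(keepbits=2)]}})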
@@ -0,0 +1,92 @@
import xarray as xr
from dask import is_dask_collection


def get_chunksizes(da, for_cdo=False, time_dim="time", chunks=None):
    """Get chunksizes of an xr.DataArray for to_netcdf(encoding), based on the original file.
    If for_cdo, ensure a time chunksize of 1 when compressed."""
    assert isinstance(da, xr.DataArray)
    if chunks:  # use new chunksizes
        return da.chunk(chunks).data.chunksize
    if for_cdo:  # take shape as chunksize and ensure time chunksize 1
        time_axis_num = da.get_axis_num(time_dim)
        chunksize = da.data.chunksize if is_dask_collection(da) else da.shape
        # https://code.mpimet.mpg.de/boards/2/topics/12598
        chunksize = list(chunksize)
        chunksize[time_axis_num] = 1
        chunksize = tuple(chunksize)
        return chunksize
    else:
        if is_dask_collection(da):
            return da.data.chunksize
        else:
            return da.shape


def get_compress_encoding(
    ds_bitrounded,
    shuffle=True,
    complevel=9,
    for_cdo=False,
    time_dim="time",
    chunks=None,
):
    """Generate encoding for ds_bitrounded.to_netcdf(encoding).

    Example:
    >>> ds = xr.tutorial.load_dataset("rasm")
    >>> get_compress_encoding(ds)
    {'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (36, 205, 275)}}
    >>> get_compress_encoding(ds, for_cdo=True)
    {'Tair': {'zlib': True, 'shuffle': True, 'complevel': 9, 'chunksizes': (1, 205, 275)}}

    """
    return {
        v: {
            "zlib": True,
            "shuffle": shuffle,
            "complevel": complevel,
            "chunksizes": get_chunksizes(
                ds_bitrounded[v], for_cdo=for_cdo, time_dim=time_dim, chunks=chunks
            ),
        }
        for v in ds_bitrounded.data_vars
    }


@xr.register_dataset_accessor("to_compressed_netcdf")
class ToCompressed_Netcdf:
    """Save to compressed netcdf wrapper.

    Example:
    >>> ds = xr.tutorial.load_dataset("rasm")
    >>> path = "rasm.nc"
    >>> ds.to_compressed_netcdf(path)
    >>> ds.to_compressed_netcdf(path, complevel=4)
    >>> ds.to_compressed_netcdf(path, for_cdo=True)

    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def __call__(
        self,
        path,
        shuffle=True,
        complevel=9,
        for_cdo=False,
        time_dim="time",
        chunks=None,
    ):
        self._obj.to_netcdf(
            path,
            encoding=get_compress_encoding(
                self._obj,
                shuffle=shuffle,
                complevel=complevel,
                for_cdo=for_cdo,
                time_dim=time_dim,
                chunks=chunks,
            ),
        )
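A hedged usage sketch for the accessor; the file names and the chunks dict are made up for illustration, and writing requires a netCDF4 backend:

import xarray as xr
import bitinformation_pipeline  # noqa: F401  importing the package registers the accessor

ds = xr.tutorial.load_dataset("rasm")

# defaults: zlib, shuffle, complevel=9, chunksizes taken from the data (or its dask chunks)
ds.to_compressed_netcdf("rasm_compressed.nc")

# cdo-friendly output: the time chunksize is forced to 1
ds.to_compressed_netcdf("rasm_for_cdo.nc", for_cdo=True)

# rechunk before writing; the sizes here are illustrative
ds.to_compressed_netcdf("rasm_chunked.nc", chunks={"time": 12, "y": 205, "x": 275})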
@@ -0,0 +1,46 @@
import numpy as np
import pytest
import xarray as xr
from xarray.tutorial import load_dataset

import bitinformation_pipeline as bp

xr.set_options(display_style="text")


@pytest.fixture(autouse=True)
def add_standard_imports(
    doctest_namespace,
):
    """imports for doctest"""
    xr.set_options(display_style="text")
    doctest_namespace["np"] = np
    doctest_namespace["xr"] = xr
    doctest_namespace["bp"] = bp
    # always seed numpy.random to make the examples deterministic
    np.random.seed(42)


@pytest.fixture()
def rasm():
    return load_dataset("rasm")


@pytest.fixture()
def air_temperature():
    return load_dataset("air_temperature")


@pytest.fixture()
def ROMS_example():
    return load_dataset("ROMS_example")


@pytest.fixture()
def era52mt():
    return load_dataset("era5-2mt-2019-03-uk.grib")


@pytest.fixture()
def eraint_uvz():
    return load_dataset("eraint_uvz")
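Not part of this diff: a hypothetical sketch of how these fixtures could exercise the new functions; the test names, keepbits value, and assertions are invented for illustration, and tmp_path is pytest's built-in fixture:

import xarray as xr
from bitinformation_pipeline import xr_bitround


def test_xr_bitround_adds_keepbits_attr(rasm):
    """xr_bitround should record the keepbits used on every variable."""
    ds_bitrounded = xr_bitround(rasm, 7)  # illustrative keepbits
    for v in ds_bitrounded.data_vars:
        assert ds_bitrounded[v].attrs["bitround_keepbits"] == 7


def test_to_compressed_netcdf_roundtrip(rasm, tmp_path):
    """The accessor should write a file that opens again with the same variables."""
    path = tmp_path / "rasm_compressed.nc"
    rasm.to_compressed_netcdf(path)
    with xr.open_dataset(path) as ds:
        assert set(ds.data_vars) == set(rasm.data_vars)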
Review comment: do you have an opinion here @observingClouds?

Review comment: That's a disadvantage of the Python version. It seems to be designed primarily for float32, which is the dtype we deal with most often IMO.
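To make the float32 point concrete, a small sketch of the behaviour that the astype("float32") branch in xr_bitround works around. It assumes numcodecs is installed; the version linked in the code comment only accepts float32, while newer releases may also handle float64, hence the try/except:

import numpy as np
from numcodecs.bitround import BitRound

codec = BitRound(keepbits=7)  # illustrative number of mantissa bits to keep

data32 = np.linspace(0.0, 1.0, 10, dtype="float32")
rounded = codec.decode(codec.encode(data32.copy()))  # same pattern as bitround() above
print(rounded.dtype, rounded.shape)

data64 = data32.astype("float64")
try:
    codec.encode(data64.copy())
except (TypeError, ValueError):
    # older numcodecs releases reject anything but float32,
    # which is why xr_bitround casts float64 variables first
    print("float64 rejected -> cast to float32 before bitrounding")
else:
    print("this numcodecs version also accepts float64")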