Skip to content

Commit

Permalink
Make SChunk.cparams and dparams to only accept dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
martaiborra committed Sep 20, 2024
1 parent b75ee5d commit e3f128f
Show file tree
Hide file tree
Showing 15 changed files with 124 additions and 136 deletions.
101 changes: 49 additions & 52 deletions src/blosc2/blosc2_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from enum import Enum

import numpy as np
from msgpack import packb, unpackb
from dataclasses import asdict

import blosc2

Expand Down Expand Up @@ -1069,17 +1070,6 @@ cdef class SChunk:
else:
# User codec
codec = self.schunk.storage.cparams.compcode
cparams_dict = {
"codec": codec,
"codec_meta": self.schunk.storage.cparams.compcode_meta,
"clevel": self.schunk.storage.cparams.clevel,
"use_dict": self.schunk.storage.cparams.use_dict,
"typesize": self.schunk.storage.cparams.typesize,
"nthreads": self.schunk.storage.cparams.nthreads,
"blocksize": self.schunk.storage.cparams.blocksize,
"splitmode": blosc2.SplitMode(self.schunk.storage.cparams.splitmode),
"tuner": blosc2.Tuner(self.schunk.storage.cparams.tuner_id),
}

filters = [0] * BLOSC2_MAX_FILTERS
filters_meta = [0] * BLOSC2_MAX_FILTERS
Expand All @@ -1090,42 +1080,50 @@ cdef class SChunk:
# User filter
filters[i] = self.schunk.filters[i]
filters_meta[i] = self.schunk.filters_meta[i]
cparams_dict["filters"] = filters
cparams_dict["filters_meta"] = filters_meta
return cparams_dict

def update_cparams(self, cparams_dict):
cparams = blosc2.CParams(
codec=codec,
codec_meta=self.schunk.storage.cparams.compcode_meta,
clevel=self.schunk.storage.cparams.clevel,
use_dict=bool(self.schunk.storage.cparams.use_dict),
typesize=self.schunk.storage.cparams.typesize,
nthreads=self.schunk.storage.cparams.nthreads,
blocksize=self.schunk.storage.cparams.blocksize,
splitmode=blosc2.SplitMode(self.schunk.storage.cparams.splitmode),
tuner=blosc2.Tuner(self.schunk.storage.cparams.tuner_id),
filters=filters,
filters_meta=filters_meta,
)

return cparams

def update_cparams(self, new_cparams):
cdef blosc2_cparams* cparams = self.schunk.storage.cparams
codec = cparams_dict.get('codec', None)
if codec is None:
cparams.compcode = cparams.compcode
else:
cparams.compcode = codec if not isinstance(codec, blosc2.Codec) else codec.value
cparams.compcode_meta = cparams_dict.get('codec_meta', cparams.compcode_meta)
cparams.clevel = cparams_dict.get('clevel', cparams.clevel)
cparams.use_dict = cparams_dict.get('use_dict', cparams.use_dict)
cparams.typesize = cparams_dict.get('typesize', cparams.typesize)
cparams.nthreads = cparams_dict.get('nthreads', cparams.nthreads)
cparams.blocksize = cparams_dict.get('blocksize', cparams.blocksize)
splitmode = cparams_dict.get('splitmode', None)
cparams.splitmode = cparams.splitmode if splitmode is None else splitmode.value

filters = cparams_dict.get('filters', None)
if filters is not None:
for i, filter in enumerate(filters):
cparams.filters[i] = filter.value if isinstance(filter, Enum) else filter
for i in range(len(filters), BLOSC2_MAX_FILTERS):
cparams.filters[i] = 0

filters_meta = cparams_dict.get('filters_meta', None)
codec = new_cparams.codec
cparams.compcode = codec if not isinstance(codec, blosc2.Codec) else codec.value
cparams.compcode_meta = new_cparams.codec_meta
cparams.clevel = new_cparams.clevel
cparams.use_dict = new_cparams.use_dict
cparams.typesize = new_cparams.typesize
cparams.nthreads = new_cparams.nthreads
cparams.blocksize = new_cparams.blocksize
cparams.splitmode = new_cparams.splitmode.value
cparams.tuner_id = new_cparams.tuner.value

filters = new_cparams.filters
for i, filter in enumerate(filters):
cparams.filters[i] = filter.value if isinstance(filter, Enum) else filter
for i in range(len(filters), BLOSC2_MAX_FILTERS):
cparams.filters[i] = 0

filters_meta = new_cparams.filters_meta
cdef int8_t meta_value
if filters_meta is not None:
for i, meta in enumerate(filters_meta):
# We still may want to encode negative values
meta_value = <int8_t> meta if meta < 0 else meta
cparams.filters_meta[i] = <uint8_t> meta_value
for i in range(len(filters_meta), BLOSC2_MAX_FILTERS):
cparams.filters_meta[i] = 0
for i, meta in enumerate(filters_meta):
# We still may want to encode negative values
meta_value = <int8_t> meta if meta < 0 else meta
cparams.filters_meta[i] = <uint8_t> meta_value
for i in range(len(filters_meta), BLOSC2_MAX_FILTERS):
cparams.filters_meta[i] = 0

_check_cparams(cparams)

Expand All @@ -1143,12 +1141,11 @@ cdef class SChunk:
self.schunk.filters_meta = self.schunk.storage.cparams.filters_meta

def get_dparams(self):
dparams_dict = {"nthreads": self.schunk.storage.dparams.nthreads}
return dparams_dict
return blosc2.DParams(nthreads=self.schunk.storage.dparams.nthreads)

def update_dparams(self, dparams_dict):
def update_dparams(self, new_dparams):
cdef blosc2_dparams* dparams = self.schunk.storage.dparams
dparams.nthreads = dparams_dict.get('nthreads', dparams.nthreads)
dparams.nthreads = new_dparams.nthreads

_check_dparams(dparams, self.schunk.storage.cparams)

Expand Down Expand Up @@ -1967,17 +1964,17 @@ def open(urlpath, mode, offset, **kwargs):
res = blosc2.NDArray(_schunk=PyCapsule_New(array.sc, <char *> "blosc2_schunk*", NULL),
_array=PyCapsule_New(array, <char *> "b2nd_array_t*", NULL))
if cparams is not None:
res.schunk.cparams = cparams
res.schunk.cparams = cparams if isinstance(cparams, blosc2.CParams) else blosc2.CParams(**cparams)
if dparams is not None:
res.schunk.dparams = dparams
res.schunk.dparams = dparams if isinstance(dparams, blosc2.DParams) else blosc2.DParams(**dparams)
res.schunk.mode = mode
else:
res = blosc2.SChunk(_schunk=PyCapsule_New(schunk, <char *> "blosc2_schunk*", NULL),
mode=mode, **kwargs)
if cparams is not None:
res.cparams = cparams
res.cparams = cparams if isinstance(cparams, blosc2.CParams) else blosc2.CParams(**cparams)
if dparams is not None:
res.dparams = dparams
res.dparams = dparams if isinstance(dparams, blosc2.DParams) else blosc2.DParams(**dparams)

return res

Expand Down
2 changes: 1 addition & 1 deletion src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,7 @@ def eval(self, item=None, **kwargs):
aux = np.empty(res_eval.shape, res_eval.dtype)
res_eval[...] = aux
res_eval.schunk.remove_prefilter(self.func.__name__)
res_eval.schunk.cparams["nthreads"] = self._cnthreads
res_eval.schunk.cparams.nthreads = self._cnthreads

return res_eval
else:
Expand Down
5 changes: 3 additions & 2 deletions src/blosc2/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ndindex
import numpy as np
from dataclasses import asdict

import blosc2
from blosc2 import SpecialValue, blosc2_ext, compute_chunks_blocks
Expand Down Expand Up @@ -1288,8 +1289,8 @@ def copy(self, dtype: np.dtype = None, **kwargs: dict) -> NDArray:
"""
if dtype is None:
dtype = self.dtype
kwargs["cparams"] = kwargs.get("cparams", self.schunk.cparams).copy()
kwargs["dparams"] = kwargs.get("dparams", self.schunk.dparams).copy()
kwargs["cparams"] = kwargs.get("cparams", asdict(self.schunk.cparams)).copy()
kwargs["dparams"] = kwargs.get("dparams", asdict(self.schunk.dparams)).copy()
if "meta" not in kwargs:
# Copy metalayers as well
meta_dict = {meta: self.schunk.meta[meta] for meta in self.schunk.meta}
Expand Down
12 changes: 6 additions & 6 deletions src/blosc2/schunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,26 +279,26 @@ def __init__(self, chunksize: int = None, data: object = None, **kwargs: dict):
self._dparams = super().get_dparams()

@property
def cparams(self) -> dict:
def cparams(self) -> blosc2.CParams:
"""
Dictionary with the compression parameters.
:class:`blosc2.CParams` instance with the compression parameters.
"""
return self._cparams

@cparams.setter
def cparams(self, value):
def cparams(self, value: blosc2.CParams) -> None:
super().update_cparams(value)
self._cparams = super().get_cparams()

@property
def dparams(self) -> dict:
def dparams(self) -> blosc2.DParams:
"""
Dictionary with the decompression parameters.
:class:`blosc2.DParams` instance with the decompression parameters.
"""
return self._dparams

@dparams.setter
def dparams(self, value):
def dparams(self, value: blosc2.DParams) -> None:
super().update_dparams(value)
self._dparams = super().get_dparams()

Expand Down
10 changes: 8 additions & 2 deletions src/blosc2/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,14 @@ class CParams:
filters_meta: list[int] = field(default_factory=default_filters_meta)
tuner: blosc2.Tuner = blosc2.Tuner.STUNE

# def __post_init__(self):
# if len(self.filters) > 6:
def __post_init__(self):
if len(self.filters) > 6:
raise ValueError("Number of filters exceeds 6")
if len(self.filters) < len(self.filters_meta):
self.filters_meta = self.filters_meta[:len(self.filters)]
warnings.warn("Changed `filters_meta` length to match `filters` length")
if len(self.filters) > len(self.filters_meta):
raise ValueError("Number of filters cannot exceed number of filters meta")


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions tests/ndarray/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def test_copy(shape, chunks1, blocks1, chunks2, blocks2, dtype):
assert a.schunk.dparams == b.schunk.dparams
for key in cparams2:
if key in ("filters", "filters_meta"):
assert b.schunk.cparams[key][: len(cparams2[key])] == cparams2[key]
assert getattr(b.schunk.cparams, key)[: len(cparams2[key])] == cparams2[key]
continue
assert b.schunk.cparams[key] == cparams2[key]
assert getattr(b.schunk.cparams, key) == cparams2[key]
assert b.chunks == tuple(chunks2)
assert b.blocks == tuple(blocks2)
assert a.dtype == b.dtype
Expand Down
8 changes: 4 additions & 4 deletions tests/ndarray/test_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def test_empty(shape, chunks, blocks, dtype, cparams, urlpath, contiguous):
assert a.blocks == blocks
assert a.dtype == dtype
assert a.schunk.typesize == dtype.itemsize
assert a.schunk.cparams["codec"] == cparams["codec"]
assert a.schunk.cparams["clevel"] == cparams["clevel"]
assert a.schunk.cparams["filters"][: len(filters)] == filters
assert a.schunk.dparams["nthreads"] == 2
assert a.schunk.cparams.codec == cparams["codec"]
assert a.schunk.cparams.clevel == cparams["clevel"]
assert a.schunk.cparams.filters[: len(filters)] == filters
assert a.schunk.dparams.nthreads == 2

blosc2.remove_urlpath(urlpath)

Expand Down
3 changes: 2 additions & 1 deletion tests/ndarray/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from dataclasses import asdict

import blosc2

Expand Down Expand Up @@ -74,7 +75,7 @@ def test_full(shape, chunks, blocks, fill_value, cparams, dparams, dtype, urlpat
urlpath=urlpath,
contiguous=contiguous,
)
assert a.schunk.dparams == dparams
assert asdict(a.schunk.dparams) == dparams
if isinstance(fill_value, bytes):
dtype = np.dtype(f"S{len(fill_value)}")
assert a.dtype == np.dtype(dtype) if dtype is not None else np.dtype(np.uint8)
Expand Down
4 changes: 2 additions & 2 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,8 @@ def test_params(array_fixture):
res = expr.eval(urlpath=urlpath, cparams=cparams, dparams=dparams, chunks=chunks, blocks=blocks)
np.testing.assert_allclose(res[:], nres)
assert res.schunk.urlpath == urlpath
assert res.schunk.cparams["nthreads"] == cparams["nthreads"]
assert res.schunk.dparams["nthreads"] == dparams["nthreads"]
assert res.schunk.cparams.nthreads == cparams["nthreads"]
assert res.schunk.dparams.nthreads == dparams["nthreads"]
assert res.chunks == chunks
assert res.blocks == blocks

Expand Down
12 changes: 6 additions & 6 deletions tests/ndarray/test_lazyudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_params(chunked_eval):
res = expr.eval(urlpath=urlpath2, chunks=(10,))
np.testing.assert_allclose(res[...], npc)
assert res.shape == npa.shape
assert res.schunk.cparams["nthreads"] == cparams["nthreads"]
assert res.schunk.cparams.nthreads == cparams["nthreads"]
assert res.schunk.urlpath == urlpath2
assert res.chunks == (10,)

Expand Down Expand Up @@ -243,7 +243,7 @@ def test_getitem(shape, chunks, blocks, slices, urlpath, contiguous, chunked_eva
assert res.schunk.urlpath is None
assert res.schunk.contiguous == contiguous
# Check dparams after a getitem and an eval
assert res.schunk.dparams["nthreads"] == dparams["nthreads"]
assert res.schunk.dparams.nthreads == dparams["nthreads"]

lazy_eval = expr[slices]
np.testing.assert_allclose(lazy_eval, npc[slices])
Expand Down Expand Up @@ -282,8 +282,8 @@ def test_eval_slice(shape, chunks, blocks, slices, urlpath, contiguous, chunked_
np.testing.assert_allclose(res[...], npc[slices])
assert res.schunk.urlpath is None
assert res.schunk.contiguous == contiguous
assert res.schunk.dparams["nthreads"] == dparams["nthreads"]
assert res.schunk.cparams["nthreads"] == blosc2.cparams_dflts["nthreads"]
assert res.schunk.dparams.nthreads == dparams["nthreads"]
assert res.schunk.cparams.nthreads == blosc2.cparams_dflts["nthreads"]
assert res.shape == npc[slices].shape

cparams = {"nthreads": 6}
Expand All @@ -294,8 +294,8 @@ def test_eval_slice(shape, chunks, blocks, slices, urlpath, contiguous, chunked_
np.testing.assert_allclose(res[...], npc[slices])
assert res.schunk.urlpath == urlpath2
assert res.schunk.contiguous == contiguous
assert res.schunk.dparams["nthreads"] == dparams["nthreads"]
assert res.schunk.cparams["nthreads"] == cparams["nthreads"]
assert res.schunk.dparams.nthreads == dparams["nthreads"]
assert res.schunk.cparams.nthreads == cparams["nthreads"]
assert res.shape == npc[slices].shape

blosc2.remove_urlpath(urlpath)
Expand Down
4 changes: 2 additions & 2 deletions tests/ndarray/test_lossy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def test_lossy(shape, cparams, dtype, urlpath, contiguous):
a = blosc2.asarray(array, cparams=cparams, urlpath=urlpath, contiguous=contiguous, mode="w")

if (
a.schunk.cparams["codec"] in (blosc2.Codec.ZFP_RATE, blosc2.Codec.ZFP_PREC, blosc2.Codec.ZFP_ACC)
or a.schunk.cparams["filters"][0] == blosc2.Filter.NDMEAN
a.schunk.cparams.codec in (blosc2.Codec.ZFP_RATE, blosc2.Codec.ZFP_PREC, blosc2.Codec.ZFP_ACC)
or a.schunk.cparams.filters[0] == blosc2.Filter.NDMEAN
):
_ = a[...]
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def test_open(contiguous, urlpath, cparams, dparams, nchunks, chunk_nitems, dtyp
cparams2 = cparams
cparams2["nthreads"] = 1
schunk_open = blosc2.open(urlpath, mode, mmap_mode=mmap_mode, cparams=cparams2)
assert schunk_open.cparams["nthreads"] == cparams2["nthreads"]
assert schunk_open.cparams.nthreads == cparams2["nthreads"]

for key in cparams:
if key == "nthreads":
continue
assert schunk_open.cparams[key] == cparams[key]
assert getattr(schunk_open.cparams, key) == cparams[key]

buffer = np.zeros(chunk_nitems, dtype=dtype)
if mode != "r":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_postfilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def postf2(input, output, offset):
def postf3(input, output, offset):
output[:] = input <= np.datetime64("1997-12-31")

schunk.dparams = {"nthreads": 1}
schunk.dparams = blosc2.DParams(nthreads=1)
post_data = np.empty(chunk_len * nchunks, dtype=output_dtype)
schunk.get_slice(0, chunk_len * nchunks, out=post_data)

Expand Down
7 changes: 5 additions & 2 deletions tests/test_prefilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from dataclasses import asdict, replace

import blosc2

Expand Down Expand Up @@ -104,7 +105,7 @@ def fill_f4(inputs_tuple, output, offset):

fill_f4((data, data2, np.pi), res, offset)

new_cparams = {"nthreads": 2}
new_cparams = replace(schunk.cparams, nthreads=2)
schunk.cparams = new_cparams

pre_data = np.empty(chunk_len * nchunks, dtype=schunk_dtype)
Expand Down Expand Up @@ -180,7 +181,9 @@ def pref2(input, output, offset):
def pref3(input, output, offset):
output[:] = input <= np.datetime64("1997-12-31")

schunk.cparams = {"nthreads": 1}
new_cparams = asdict(schunk.cparams)
new_cparams["nthreads"] = 1
schunk.cparams = blosc2.CParams(**new_cparams)

schunk[: nchunks * chunk_len] = data
post_data = np.empty(chunk_len * nchunks, dtype=schunk_dtype)
Expand Down
Loading

0 comments on commit e3f128f

Please sign in to comment.