Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement fletcher32 #412

Merged
merged 10 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numcodecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,6 @@
register_codec(VLenUTF8)
register_codec(VLenBytes)
register_codec(VLenArray)

# Fletcher32 lives in the compiled ``numcodecs.fletcher32`` module; import it
# here and register it in the same way as the codecs above.
from numcodecs.fletcher32 import Fletcher32
register_codec(Fletcher32)
54 changes: 54 additions & 0 deletions numcodecs/fletcher32.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# cython: boundscheck=False
# cython: wraparound=False
# cython: overflowcheck=False
# cython: cdivision=True

import struct
import numpy as np

from numcodecs.abc import Codec
from numcodecs.compat import ensure_contiguous_ndarray

from libc.stdint cimport uint8_t, uint16_t, uint32_t
martindurant marked this conversation as resolved.
Show resolved Hide resolved

cpdef uint32_t fletcher32(const uint16_t[::1] data):
    """Return the Fletcher-32 checksum of a contiguous buffer of 16-bit words.

    Two running sums are kept, each reduced modulo 65535 after every word:
    ``sum1`` accumulates the words and ``sum2`` accumulates ``sum1``.  The
    result packs ``sum2`` into the upper 16 bits and ``sum1`` into the lower
    16 bits.

    Parameters
    ----------
    data : contiguous 1-D buffer of uint16
        The words to checksum, read exactly as stored in the buffer.

    Returns
    -------
    uint32_t
        ``(sum2 << 16) | sum1``; ``0`` for an empty buffer.
    """
    cdef:
        uint32_t sum1 = 0
        uint32_t sum2 = 0
        # Py_ssize_t is the type of ``data.shape``; a C ``int`` would truncate
        # the length of very large buffers, and with boundscheck disabled that
        # truncation would go unnoticed.
        Py_ssize_t index
        Py_ssize_t size = data.shape[0]

    for index in range(size):
        # Both sums stay below 0xffff after the modulo, so the additions
        # cannot overflow 32 bits.
        sum1 = (sum1 + data[index]) % 0xffff
        sum2 = (sum2 + sum1) % 0xffff

    return (sum2 << 16) | sum1


class Fletcher32(Codec):
    """Codec that appends a Fletcher-32 checksum on encode and verifies it on decode.

    ``encode`` returns the input bytes followed by the checksum packed as a
    little-endian unsigned 32-bit integer.  ``decode`` recomputes the checksum
    over everything except the last four bytes, compares it with the stored
    value and returns the payload without the checksum.
    """

    codec_id = "fletcher32"

    @staticmethod
    def _checksum(payload):
        """Return the Fletcher-32 checksum of a 1-D contiguous uint8 array."""
        if payload.nbytes % 2:
            # Odd number of bytes: pad with a single zero byte so the data
            # can be read as whole 16-bit words.
            words = np.frombuffer(payload.tobytes() + b"\x00", dtype="uint16")
        else:
            words = payload.view("uint16")
        return fletcher32(words)

    def encode(self, buf):
        """Return the bytes of ``buf`` followed by their 4-byte checksum.

        Parameters
        ----------
        buf : buffer-like
            Data to protect; it is passed through ``ensure_contiguous_ndarray``.

        Returns
        -------
        bytes
            The raw bytes of ``buf`` plus the checksum packed with ``"<I"``.
        """
        # View as bytes first so the odd/even decision is made on the byte
        # count rather than on the number of (possibly multi-byte) elements.
        data = ensure_contiguous_ndarray(buf).ravel().view("uint8")
        return data.tobytes() + struct.pack("<I", self._checksum(data))

    def decode(self, buf, out=None):
        """Verify the trailing checksum of ``buf`` and return the payload.

        Parameters
        ----------
        buf : buffer-like
            Encoded data: payload bytes followed by a 4-byte checksum.
        out : ndarray, optional
            If given, the payload bytes are copied into it and it is returned.

        Returns
        -------
        memoryview or ndarray
            A view of the payload (without the checksum), or ``out``.

        Raises
        ------
        ValueError
            If ``buf`` is too short to contain a checksum.
        RuntimeError
            If the stored checksum does not match the recomputed one.
        """
        b = ensure_contiguous_ndarray(buf).view("uint8")
        if b.nbytes < 4:
            raise ValueError(
                "buffer of %d bytes is too short to hold a fletcher32 checksum"
                % b.nbytes
            )
        payload = b[:-4]
        # The checksum covers the payload only, never the stored checksum.
        val = self._checksum(payload)
        # Read the stored value with the same byte order encode() wrote it in.
        found = struct.unpack("<I", b[-4:].tobytes())[0]
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if val != found:
            raise RuntimeError(
                "The fletcher32 checksum of the data (%d) did not match the "
                "stored checksum (%d)" % (val, found)
            )
        # ``out is not None``: the truth value of an array is ambiguous for
        # multi-element arrays and False for empty or all-zero ones.
        if out is not None:
            out.view("uint8")[:] = payload
            return out
        return memoryview(payload)
24 changes: 24 additions & 0 deletions numcodecs/tests/test_fletcher32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import pytest

from numcodecs.fletcher32 import Fletcher32, fletcher32


@pytest.mark.parametrize("inval,outval", [
    (b"abcdef", 1448095018),
    (b"abcdefgh", 3957429649),
])
def test_vectors(inval, outval):
    """Check ``fletcher32`` against known checksums for fixed byte strings."""
    # Reinterpret the raw bytes as 16-bit words, the input type of fletcher32.
    byte_values = np.array(list(inval), dtype="uint8")
    words = byte_values.view('uint16')
    checksum = fletcher32(words)
    assert checksum == outval


@pytest.mark.parametrize(
    "dtype",
    ["uint8", "int32", "float32"]
)
def test_with_data(dtype):
    """Round-trip an array of each dtype through Fletcher32 encode/decode."""
    # np.arange gives deterministic, NaN-free values.  np.empty returns
    # uninitialized memory, which for float32 may contain NaNs; those never
    # compare equal to themselves and would make the equality check fail
    # even though the round trip is byte-exact.
    data = np.arange(100, dtype=dtype)
    f = Fletcher32()
    arr = np.frombuffer(f.decode(f.encode(data)), dtype=dtype)
    assert (arr == data).all()
martindurant marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 27 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,31 @@ def vlen_extension():
return extensions


def fletcher_extension():
    """Return the list of extension modules for the fletcher32 codec.

    The list holds a single ``Extension`` named ``numcodecs.fletcher32``,
    built from ``numcodecs/fletcher32.pyx`` with a copy of the shared base
    compile arguments and ``numcodecs`` on the include path.
    """
    info('setting up fletcher32 extension')

    # Copy so this extension never mutates the shared base argument list.
    compile_args = base_compile_args.copy()

    # No macros are defined by default.
    # macros += [('CYTHON_TRACE', '1')]
    macros = []

    fletcher32_module = Extension(
        'numcodecs.fletcher32',
        sources=['numcodecs/fletcher32.pyx'],
        include_dirs=['numcodecs'],
        define_macros=macros,
        extra_compile_args=compile_args,
    )

    return [fletcher32_module]


def compat_extension():
info('setting up compat extension')

Expand Down Expand Up @@ -265,7 +290,8 @@ def run_setup(with_extensions):

if with_extensions:
ext_modules = (blosc_extension() + zstd_extension() + lz4_extension() +
compat_extension() + shuffle_extension() + vlen_extension())
compat_extension() + shuffle_extension() + vlen_extension() +
fletcher_extension())

cmdclass = dict(build_ext=ve_build_ext)
else:
Expand Down