Skip to content

Commit

Permalink
use reconstructor function for pickling, see also issue #206
Browse files Browse the repository at this point in the history
  • Loading branch information
ilanschnell committed Aug 4, 2023
1 parent 73bfe6e commit 788c188
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
1 change: 1 addition & 0 deletions bitarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import absolute_import

from bitarray._bitarray import (bitarray, decodetree, _sysinfo,
_bitarray_reconstructor,
get_default_endian, _set_default_endian,
__version__)

Expand Down
80 changes: 65 additions & 15 deletions bitarray/_bitarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1163,8 +1163,20 @@ static PyObject *
bitarray_reduce(bitarrayobject *self)
{
const Py_ssize_t nbytes = Py_SIZE(self);
PyObject *dict, *repr = NULL, *result = NULL;
char *str;
static PyObject *reconstructor = NULL;
PyObject *dict, *bytes, *result = NULL;

if (reconstructor == NULL) {
PyObject *bitarray_module;

if ((bitarray_module = PyImport_ImportModule("bitarray")) == NULL)
return NULL;
reconstructor = PyObject_GetAttrString(bitarray_module,
"_bitarray_reconstructor");
Py_DECREF(bitarray_module);
if (reconstructor == NULL)
return NULL;
}

dict = PyObject_GetAttrString((PyObject *) self, "__dict__");
if (dict == NULL) {
Expand All @@ -1173,21 +1185,17 @@ bitarray_reduce(bitarrayobject *self)
Py_INCREF(dict);
}

repr = PyBytes_FromStringAndSize(NULL, nbytes + 1);
if (repr == NULL)
goto error;

str = PyBytes_AsString(repr);
/* first byte contains the number of pad bits */
*str = (char) set_padbits(self);
/* remaining bytes contain buffer */
memcpy(str + 1, self->ob_item, (size_t) nbytes);
bytes = PyBytes_FromStringAndSize(NULL, nbytes);
if (bytes == NULL) {
Py_DECREF(dict);
return NULL;
}
memcpy(PyBytes_AsString(bytes), self->ob_item, (size_t) nbytes);

result = Py_BuildValue("O(Os)O", Py_TYPE(self),
repr, ENDIAN_STR(self->endian), dict);
error:
result = Py_BuildValue(
"O(OnOsi)O", reconstructor, Py_TYPE(self), self->nbits, bytes,
ENDIAN_STR(self->endian), self->readonly, dict);
Py_DECREF(dict);
Py_XDECREF(repr);
return result;
}

Expand Down Expand Up @@ -4001,6 +4009,45 @@ static PyTypeObject Bitarray_Type = {

/***************************** Module functions ***************************/

static PyObject *
reconstructor(PyObject *module, PyObject *args)
{
PyTypeObject *type;
Py_ssize_t nbits, nbytes;
PyObject *res, *bytes;
char *endian_str;
int endian, readonly;

if (!PyArg_ParseTuple(args, "OnOsi:_bitarray_reconstructor",
&type, &nbits, &bytes, &endian_str, &readonly))
return NULL;

if ((endian = endian_from_string(endian_str)) < 0)
return NULL;

if (!PyBytes_Check(bytes))
return PyErr_Format(PyExc_TypeError, "bytes expected, got '%s'",
Py_TYPE(bytes)->tp_name);

nbytes = PyBytes_GET_SIZE(bytes);
if (nbytes != BYTES(nbits))
return PyErr_Format(PyExc_ValueError,
"size mismatch: %zd != %zd (%zd bits)",
nbytes, BYTES(nbits), nbits);

res = newbitarrayobject(type, nbits, endian);
if (res == NULL)
return NULL;
memcpy(((bitarrayobject *) res)->ob_item, PyBytes_AS_STRING(bytes),
(size_t) nbytes);
if (readonly)
((bitarrayobject *) res)->readonly = 1;
return res;
}

PyDoc_STRVAR(reconstructor_doc, "Internal. Used for pickling support.");


static PyObject *
get_default_endian(PyObject *module)
{
Expand Down Expand Up @@ -4079,6 +4126,9 @@ Return tuple containing:\n\


static PyMethodDef module_functions[] = {
{"_bitarray_reconstructor",
(PyCFunction) reconstructor, METH_VARARGS,
reconstructor_doc},
{"get_default_endian", (PyCFunction) get_default_endian, METH_NOARGS,
get_default_endian_doc},
{"_set_default_endian", (PyCFunction) set_default_endian, METH_VARARGS,
Expand Down

0 comments on commit 788c188

Please sign in to comment.