Skip to content

Commit

Permalink
Merge pull request #290 from arcondello/response_serialize
Browse files Browse the repository at this point in the history
SampleSet.to_serializable and SampleSet.from_serializable
  • Loading branch information
arcondello authored Oct 11, 2018
2 parents b1b4741 + 707d68b commit 4f5d7ab
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 7 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include dimod/io/bqm_json_schema.json
include dimod/io/sampleset_json_schema.json
recursive-include dimod/roof_duality/src *.hpp *.cpp
recursive-include dimod/roof_duality *.cpp *.pyx
96 changes: 90 additions & 6 deletions dimod/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,30 @@
from __future__ import absolute_import

import json
import base64
import operator

from functools import reduce
from pkg_resources import resource_filename

import jsonschema
import numpy as np

from six import iteritems

from dimod.binary_quadratic_model import BinaryQuadraticModel
from dimod.package_info import __version__
from dimod.response import Response
from dimod.sampleset import SampleSet
from dimod.vartypes import Vartype

# Schema version string; it is written into the "version" entry of the
# serialized dictionaries produced by the encoder in this module.
# NOTE(review): this diff view lists both `bqm_json_schema_version` and
# `json_schema_version`; presumably the first is the pre-change name and the
# second its replacement -- confirm against the checked-in file.
bqm_json_schema_version = "1.0.0"
json_schema_version = "1.0.0"

# Load the JSON schema for serialized binary quadratic models; the file is
# located relative to this module via pkg_resources.
with open(resource_filename(__name__, 'bqm_json_schema.json'), 'r') as schema_file:
bqm_json_schema = json.load(schema_file)

# Load the JSON schema for serialized sample sets in the same way; it is used
# by the sample set decode hook to recognise serialized sample sets.
with open(resource_filename(__name__, 'sampleset_json_schema.json'), 'r') as schema_file:
sampleset_json_schema = json.load(schema_file)


def _decode_label(label):
"""Convert a list label into a tuple. Works recursively on nested lists."""
Expand All @@ -48,6 +56,55 @@ def _encode_label(label):
return label


def _pack_record(record):
    """Encode every field of a numpy record array as a JSON-friendly dict.

    Args:
        record: A numpy structured/record array.

    Returns:
        dict: Maps each field name to a dict with three entries: 'data'
        (the field's bytes, base64-encoded as a text string), 'shape' (the
        field's array shape) and 'dtype' (the field's dtype as a string).

    The 'sample' field is stored compactly: each entry is reduced to a single
    bit (set when the value is greater than zero) and the bits are packed into
    bytes. All other fields are stored as their raw bytes.
    """
    def _encode_field(name):
        column = record[name]

        if name == 'sample':
            # one bit per entry; only the sign of the value survives
            payload = np.packbits(column > 0).tobytes()
        else:
            payload = column.tobytes()

        encoded = base64.b64encode(payload)
        return {'data': encoded.decode("UTF-8"),
                'shape': column.shape,
                'dtype': str(column.dtype)}

    return {name: _encode_field(name) for name in record.dtype.fields}


def _prod(iterable):
    """Return the product of all items in `iterable` (1 when it is empty)."""
    total = 1
    for factor in iterable:
        total = total * factor
    return total


def _unpack_record(obj, vartype):
    """Rebuild a numpy record array from a document made by ``_pack_record``.

    Args:
        obj (dict):
            Maps each field name to a dict with the keys 'data' (base64
            string), 'shape' (sequence of ints) and 'dtype' (numpy dtype
            string).
        vartype:
            Variable type of the samples. When it is ``Vartype.SPIN`` the
            unpacked 0/1 sample bits are mapped to -1/+1; otherwise the bits
            are kept as they are.

    Returns:
        :obj:`numpy.recarray`: A record array with one field per entry of
        `obj`. Its length is the leading dimension shared by all fields.

    Raises:
        ValueError: If `obj` has no fields, if a field has an empty shape, or
            if the fields disagree on their leading dimension.
    """
    # Validate every shape before decoding anything. The number of rows must
    # be agreed on by all fields; relying on whichever field happens to be
    # visited last would let a mismatched field be silently broadcast.
    num_rows = None
    for field, data in obj.items():
        shape = tuple(data['shape'])
        if not shape:
            raise ValueError("field {!r} has an empty shape".format(field))
        if num_rows is None:
            num_rows = shape[0]
        elif shape[0] != num_rows:
            raise ValueError("field {!r} has {} rows, expected {}".format(field, shape[0], num_rows))

    if num_rows is None:
        raise ValueError("record has no fields")

    fields = {}
    datatypes = []

    for field, data in obj.items():

        shape = tuple(data['shape'])
        dtype = data['dtype']
        raw_bytes = base64.b64decode(data['data'])

        if field == 'sample':
            # samples were bit-packed, so expand back to one value per bit and
            # drop the padding bits that packbits appended to fill the last byte
            bits = np.unpackbits(np.frombuffer(raw_bytes, dtype=np.uint8))
            arr = bits[:_prod(shape)].astype(dtype).reshape(shape)

            if vartype is Vartype.SPIN:
                # map {0, 1} -> {-1, +1}
                arr = 2 * arr - 1
        else:
            raw = np.frombuffer(raw_bytes, dtype=dtype)
            arr = raw[:_prod(shape)].reshape(shape)

        fields[field] = arr
        # the leading axis indexes rows; the remaining axes belong to the field
        datatypes.append((field, dtype, shape[1:]))

    record = np.rec.array(np.zeros(num_rows, dtype=datatypes))

    for field, arr in fields.items():
        record[field] = arr

    return record


def bqm_decode_hook(dct, cls=None):
"""Decode hook as can be used with json.loads."""

Expand All @@ -68,6 +125,22 @@ def bqm_decode_hook(dct, cls=None):
return dct


def sampleset_decode_hook(dct, cls=None):
    """Object hook for ``json.loads`` that rebuilds serialized sample sets.

    Args:
        dct (dict):
            A decoded JSON object.
        cls (optional):
            Class used to construct the result; defaults to ``SampleSet``.

    Returns:
        An instance of `cls` when `dct` validates against the sample set
        schema; otherwise `dct` itself, unchanged.
    """
    target_cls = SampleSet if cls is None else cls

    validator = jsonschema.Draft4Validator(sampleset_json_schema)
    if not validator.is_valid(dct):
        # not a serialized sample set, hand it back untouched
        return dct

    vartype = Vartype[dct['variable_type']]
    unpacked = _unpack_record(dct['record'], vartype)
    labels = dct['variable_labels']
    info = dct['info']
    return target_cls(unpacked, labels, info, vartype)


class DimodEncoder(json.JSONEncoder):
"""Subclass the JSONEncoder for dimod objects."""
def default(self, obj):
Expand All @@ -85,14 +158,25 @@ def default(self, obj):
"quadratic_terms": list(self._quadratic_biases(obj.quadratic)),
"offset": obj.offset,
"variable_type": vartype_string,
"version": {"dimod": __version__, "bqm_schema": bqm_json_schema_version},
"version": {"dimod": __version__, "bqm_schema": json_schema_version},
"variable_labels": list(self._variable_labels(obj.linear)),
"info": obj.info}
return json_dict

elif isinstance(obj, Response):
# we will eventually want to implement this
raise NotImplementedError
elif isinstance(obj, SampleSet):

if obj.vartype is Vartype.SPIN:
vartype_string = 'SPIN'
elif obj.vartype is Vartype.BINARY:
vartype_string = 'BINARY'
else:
raise RuntimeError("unknown vartype")

return {"record": _pack_record(obj.record),
"variable_type": vartype_string,
"info": obj.info,
"version": {"dimod": __version__, "sampleset_schema": json_schema_version},
"variable_labels": list(obj.variables)}

return json.JSONEncoder.default(self, obj)

Expand Down
41 changes: 41 additions & 0 deletions dimod/io/sampleset_json_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "sample set schema",
"type": "object",
"required": ["record",
"info",
"variable_labels",
"variable_type",
"version"],
"properties": {
"record": {
"type": "object",
"required": ["sample",
"energy",
"num_occurrences"]
},
"info": {
"type": "object"
},
"variable_labels": {
"type": "array",
"items": {
"type": ["integer", "string", "array"],
"minimum": 0
}
},
"variable_type": {
"type":"string",
"enum":["SPIN", "BINARY"]
},
"version": {
"type": "object",
"required": ["sampleset_schema", "dimod"],
"properties": {
"sampleset_schema": {
"enum":["1.0.0"]
}
}
}
}
}
58 changes: 57 additions & 1 deletion dimod/sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def __eq__(self, other):
return False

# check that all the fields match in record, order doesn't matter
if self.record.dtype.fields != other.record.dtype.fields:
if self.record.dtype.fields.keys() != other.record.dtype.fields.keys():
return False
for field in self.record.dtype.fields:
if field == 'sample':
Expand Down Expand Up @@ -711,6 +711,62 @@ def relabel_variables(self, mapping, inplace=True):
self._variables = VariableIndexView(mapping.get(v, v) for v in self.variable_labels)
return self

###############################################################################################
# Serialization
###############################################################################################

def to_serializable(self):
    """Return a serializable representation of this sample set.

    Returns:
        dict: An object that can be serialized.

    Examples:
        Encode using JSON

        >>> import dimod
        >>> import json
        ...
        >>> samples = dimod.SampleSet.from_samples([-1, 1, -1], dimod.SPIN, energy=-.5)
        >>> s = json.dumps(samples.to_serializable())

    See also:
        :meth:`~.SampleSet.from_serializable`

    """
    # imported here rather than at module level
    from dimod.io.json import DimodEncoder

    encoder = DimodEncoder()
    return encoder.default(self)

@classmethod
def from_serializable(cls, obj):
    """Rebuild a sample set from its serializable representation.

    Args:
        obj (dict):
            A sample set serialized by :meth:`~.SampleSet.to_serializable`.

    Returns:
        :obj:`.SampleSet`

    Examples:
        Encode and decode using JSON

        >>> import dimod
        >>> import json
        ...
        >>> samples = dimod.SampleSet.from_samples([-1, 1, -1], dimod.SPIN, energy=-.5)
        >>> s = json.dumps(samples.to_serializable())
        >>> new_samples = dimod.SampleSet.from_serializable(json.loads(s))

    See also:
        :meth:`~.SampleSet.to_serializable`

    """
    # imported here rather than at module level
    from dimod.io.json import sampleset_decode_hook as decode

    return decode(obj, cls=cls)


def _samples_dicts_to_array(samples_dicts):
"""Convert an iterable of samples where each sample is a dict to a numpy 2d array. Also
Expand Down
44 changes: 44 additions & 0 deletions tests/test_sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# ================================================================================================
import unittest
import json

import numpy as np

Expand Down Expand Up @@ -81,3 +82,46 @@ def test_eq_ordered(self):
self.assertEqual(ss0, ss1)
self.assertNotEqual(ss0, ss2)
self.assertNotEqual(ss1, ss3)


class TestSampleSetSerialization(unittest.TestCase):
    """Round-trip checks for SampleSet.to_serializable / from_serializable."""

    def _assert_json_round_trip(self, sample_set):
        # serialize to a JSON string, decode it again and compare
        encoded = json.dumps(sample_set.to_serializable())
        decoded = dimod.SampleSet.from_serializable(json.loads(encoded))
        self.assertEqual(sample_set, decoded)

    def test_functional_simple_shapes(self):
        # sweep over small (num_samples, num_variables) shapes; odd sample
        # counts use SPIN values, even ones use BINARY values
        for num_samples in range(1, 9):
            for num_variables in range(1, 15):

                values = np.random.randint(2, size=(num_samples, num_variables))

                if num_samples % 2:
                    vartype = dimod.SPIN
                    values = 2 * values - 1
                else:
                    vartype = dimod.BINARY

                original = dimod.SampleSet.from_samples(values, vartype,
                                                        energy=np.ones(num_samples))
                restored = dimod.SampleSet.from_serializable(original.to_serializable())
                self.assertEqual(original, restored)

    def test_functional_json(self):
        # default variable labels, passed through an actual JSON string
        num_variables = 4
        num_samples = 7

        values = np.random.randint(2, size=(num_samples, num_variables))
        original = dimod.SampleSet.from_samples(values, dimod.BINARY,
                                                energy=np.ones(num_samples))

        self._assert_json_round_trip(original)

    def test_functional_str(self):
        # string variable labels, passed through an actual JSON string
        num_variables = 4
        num_samples = 7

        values = np.random.randint(2, size=(num_samples, num_variables))
        original = dimod.SampleSet.from_samples((values, 'abcd'), dimod.BINARY,
                                                energy=np.ones(num_samples))

        self._assert_json_round_trip(original)

0 comments on commit 4f5d7ab

Please sign in to comment.