Skip to content

Commit

Permalink
Merge pull request #297 from arcondello/samplesetkwargs
Browse files Browse the repository at this point in the history
Fix SampleSet.from_samples misnamed variables
  • Loading branch information
arcondello authored Oct 22, 2018
2 parents 39c9e85 + df71384 commit 79c37c6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
12 changes: 2 additions & 10 deletions dimod/sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ def __init__(self, record, variables, info, vartype):
self._vartype = vartype

@classmethod
def from_samples(cls, samples_like, vartype, energy,
info=None,
num_occurrences=None, **vectors):
def from_samples(cls, samples_like, vartype, energy, info=None, num_occurrences=None, **vectors):
"""Build a SampleSet from raw samples.
Args:
Expand Down Expand Up @@ -261,13 +259,7 @@ def from_samples(cls, samples_like, vartype, energy,
('num_occurrences', num_occurrences.dtype)]
for key, vector in vectors.items():
vectors[key] = vector = np.asarray(vector)

if len(vector.shape) < 1 or vector.shape[0] != num_samples:
msg = ('{} and sample have a mismatched shape {}, {}. They must have the same size '
'in the first axis.').format(kwarg, vector.shape, sample.shape)
raise ValueError(msg)

datatypes.append((kwarg, vector.dtype, vector.shape[1:]))
datatypes.append((key, vector.dtype, vector.shape[1:]))

record = np.rec.array(np.zeros(num_samples, dtype=datatypes))
record['sample'] = samples
Expand Down
20 changes: 20 additions & 0 deletions tests/test_sampleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ def test_from_samples_iterator(self):
self.assertEqual(len(ss0), len(ss1))
self.assertEqual(ss0, ss1)

def test_from_samples_fields_single(self):
ss = dimod.SampleSet.from_samples({'a': 1, 'b': -1}, dimod.SPIN, energy=1.0, a=5, b='b')

self.assertIn('a', ss.record.dtype.fields)
self.assertIn('b', ss.record.dtype.fields)
self.assertTrue(all(ss.record.a == [5]))
self.assertTrue(all(ss.record.b == ['b']))

def test_from_samples_fields_multiple(self):
ss = dimod.SampleSet.from_samples(np.ones((2, 5)), dimod.BINARY, energy=[0, 0], a=[-5, 5], b=['a', 'b'])

self.assertIn('a', ss.record.dtype.fields)
self.assertIn('b', ss.record.dtype.fields)
self.assertTrue(all(ss.record.a == [-5, 5]))
self.assertTrue(all(ss.record.b == ['a', 'b']))

def test_mismatched_shapes(self):
with self.assertRaises(ValueError):
dimod.SampleSet.from_samples(np.ones((3, 5)), dimod.SPIN, energy=[5, 5])

def test_eq_ordered(self):
# samplesets should be equal regardless of variable order
ss0 = dimod.SampleSet.from_samples(([-1, 1], 'ab'), dimod.SPIN, energy=0.0)
Expand Down

0 comments on commit 79c37c6

Please sign in to comment.