diff --git a/tests/unit/nupic/algorithms/sp_overlap_test.py b/tests/unit/nupic/algorithms/sp_overlap_test.py index 1ca3af5d62..866d518406 100755 --- a/tests/unit/nupic/algorithms/sp_overlap_test.py +++ b/tests/unit/nupic/algorithms/sp_overlap_test.py @@ -130,10 +130,10 @@ def frequency(self, maxval=maxVal, periodic=False, forced=True) # forced: it's strongly recommended to use w>=21, in the example we force skip the check for readibility for y in xrange(numColors): temp = enc.encode(rnd.random()*maxVal) - colors.append(numpy.array(temp, dtype=realDType)) + colors.append(numpy.array(temp, dtype=numpy.uint32)) else: for y in xrange(numColors): - sdr = numpy.zeros(n, dtype=realDType) + sdr = numpy.zeros(n, dtype=numpy.uint32) # Randomly setting w out of n bits to 1 sdr[rnd.sample(xrange(n), w)] = 1 colors.append(sdr) @@ -144,7 +144,7 @@ def frequency(self, for i in xrange(numColors): # TODO: See https://github.com/numenta/nupic/issues/2072 spInput = colors[i] - onCells = numpy.zeros(columnDimensions) + onCells = numpy.zeros(columnDimensions, dtype=numpy.uint32) spImpl.compute(spInput, True, onCells) spOutput.append(onCells.tolist()) activeCoincIndices = set(onCells.nonzero()[0]) diff --git a/tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py b/tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py index 0e7ab81626..9a36ebb1f0 100755 --- a/tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py +++ b/tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py @@ -178,6 +178,24 @@ def testUpdateDutyCycles(self): self.assertEqual(list(resultOverlapArr2), list(trueOverlapArr2)) + def testComputeParametersValidation(self): + sp = SpatialPooler(inputDimensions=[5], columnDimensions=[5]) + inputGood = np.ones(5, dtype=uintDType) + outGood = np.zeros(5, dtype=uintDType) + inputBad = np.ones(5, dtype=realDType) + outBad = np.zeros(5, dtype=realDType) + + # Validate good parameters + sp.compute(inputGood, False, outGood) + + # Validate bad input + with self.assertRaises(RuntimeError): + sp.compute(inputBad, False, outGood) + + # Validate bad output + with self.assertRaises(RuntimeError): + sp.compute(inputGood, False, outBad) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py b/tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py index 37761c1b2b..11e0e87c57 100755 --- a/tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py +++ b/tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py @@ -42,8 +42,8 @@ def setUp(self): def testCompute(self): # Check that there are no errors in call to compute - inputVector = numpy.ones(5) - activeArray = numpy.zeros(5) + inputVector = numpy.ones(5, dtype=uintType) + activeArray = numpy.zeros(5, dtype=uintType) self.sp.compute(inputVector, True, activeArray)