Skip to content

Commit

Permalink
add SpatialPooler and TemporalMemory pickle/unpickle tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupino committed Jun 29, 2019
1 parent d749c95 commit fbef39b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
17 changes: 17 additions & 0 deletions bindings/py/tests/algorithms/spatial_pooler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from htm.algorithms import SpatialPooler as SP
import numpy as np

try:
import cPickle as pickle # For python 2
except ImportError:
import pickle # For python 3

class SpatialPoolerTest(unittest.TestCase):

def testCompute(self):
Expand Down Expand Up @@ -160,6 +165,18 @@ def testGetConnectedCountsUint32(self):
print("Successfully caught incorrect uint numpy data length")
pass

def testNupicSpatialPoolerPickling(self):
"""Test pickling / unpickling of NuPIC SpatialPooler."""

# Simple test: make sure that dumping / loading works...
sp = SP()
pickledSp = pickle.dumps(sp)

sp2 = pickle.loads(pickledSp)

self.assertEqual(sp.getNumColumns(), sp2.getNumColumns(),
"Simple NuPIC SpatialPooler pickle/unpickle failed.")


if __name__ == "__main__":
unittest.main()
18 changes: 18 additions & 0 deletions bindings/py/tests/temporal_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import unittest
import pytest

try:
import cPickle as pickle # For python 2
except ImportError:
import pickle # For python 3

class TemporalMemoryBindingsTest(unittest.TestCase):
@pytest.mark.skip(reason="Calling arguments on compute()...another PR")
@staticmethod
Expand All @@ -28,3 +33,16 @@ def testIssue807():

tm = TemporalMemory()
tm.compute(set(), True)

def testNupicTemporalMemoryPickling(self):
"""Test pickling / unpickling of NuPIC TemporalMemory."""
from htm.bindings.algorithms import TemporalMemory

# Simple test: make sure that dumping / loading works...
tm = TemporalMemory(columnDimensions=(16,))
pickledTm = pickle.dumps(tm)

tm2 = pickle.loads(pickledTm)

self.assertEqual(tm.numberOfCells(), tm2.numberOfCells(),
"Simple NuPIC TemporalMemory pickle/unpickle failed.")

0 comments on commit fbef39b

Please sign in to comment.