Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

NUP-2354 Implement serialization for TMRegion #3657

Merged
merged 4 commits into from
Jun 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/nupic/algorithms/backtracking_tm_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self,
permanenceDec=0.10,
permanenceMax=1.0,
activationThreshold=12,
predictedSegmentDecrement=0,
predictedSegmentDecrement=0.0,
maxSegmentsPerCell=255,
maxSynapsesPerSegment=255,
globalDecay=0.10,
Expand Down Expand Up @@ -82,6 +82,21 @@ def __init__(self,
self.infActiveState = {"t": None}


@classmethod
def read(cls, proto):
"""
Intercepts TemporalMemory deserialization request in order to initialize
`self.infActiveState`

@param proto (DynamicStructBuilder) Proto object

@return (TemporalMemory) TemporalMemory shim instance
"""
tm = super(TMShimMixin, cls).read(proto)
tm.infActiveState = {"t": None}
return tm


def compute(self, bottomUpInput, enableLearn, computeInfOutput=None):
"""
(From `backtracking_tm.py`)
Expand Down Expand Up @@ -171,7 +186,7 @@ def __init__(self,
permanenceDec=0.10,
permanenceMax=1.0,
activationThreshold=12,
predictedSegmentDecrement=0,
predictedSegmentDecrement=0.0,
maxSegmentsPerCell=255,
maxSynapsesPerSegment=255,
globalDecay=0.10,
Expand Down
19 changes: 19 additions & 0 deletions src/nupic/regions/tm_region.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@0xb9d11462f08c1dee;

using import "/nupic/proto/TemporalMemoryProto.capnp".TemporalMemoryProto;

# Next ID: 11
struct TMRegionProto {
temporalImp @0 :Text;
temporalMemory @1 :TemporalMemoryProto;
columnCount @2 :UInt32;
inputWidth @3 :UInt32;
cellsPerColumn @4 :UInt32;
learningMode @5 :Bool;
inferenceMode @6 :Bool;
anomalyMode @7 :Bool;
topDownMode @8 :Bool;
computePredictedActiveCellIndices @9 :Bool;
orColumnOutputs @10 :Bool;
}

95 changes: 78 additions & 17 deletions src/nupic/regions/tm_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@

import numpy
import os

try:
import capnp
except ImportError:
capnp = None

from nupic.bindings.regions.PyRegion import PyRegion

from nupic.algorithms import (anomaly, backtracking_tm, backtracking_tm_cpp,
backtracking_tm_shim)
if capnp:
from nupic.regions.tm_region_capnp import TMRegionProto

from nupic.support import getArgumentDescriptions



gDefaultTemporalImp = 'py'


Expand Down Expand Up @@ -183,7 +194,7 @@ def getConstraints(arg):
cells per column must also be specified and the output size of the region
should be set the same as columnCount""",
accessMode='Read',
dataType='UInt32',
dataType='Bool',
count=1,
constraints='bool'),

Expand All @@ -208,41 +219,44 @@ def getConstraints(arg):
# The last group is for parameters that aren't strictly spatial or temporal
otherSpec = dict(
learningMode=dict(
description='1 if the node is learning (default 1).',
description='True if the node is learning (default True).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=True,
constraints='bool'),

inferenceMode=dict(
description='1 if the node is inferring (default 0).',
description='True if the node is inferring (default False).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=False,
constraints='bool'),

computePredictedActiveCellIndices=dict(
description='1 if active and predicted active indices should be computed',
description='True if active and predicted active indices should be computed',
accessMode='Create',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=0,
defaultValue=False,
constraints='bool'),

anomalyMode=dict(
description='1 if an anomaly score is being computed',
description='True if an anomaly score is being computed',
accessMode='Create',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=0,
defaultValue=False,
constraints='bool'),

topDownMode=dict(
description='1 if the node should do top down compute on the next call '
'to compute into topDownOut (default 0).',
description='True if the node should do top down compute on the next call '
'to compute into topDownOut (default False).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=False,
constraints='bool'),

activeOutputCount=dict(
Expand Down Expand Up @@ -319,7 +333,6 @@ def __init__(self,
computePredictedActiveCellIndices=False,

**kwargs):

# Which Temporal implementation?
TemporalClass = _getTPClass(temporalImp)

Expand Down Expand Up @@ -362,7 +375,7 @@ def __init__(self,
self._fpLogTPOutput = None

# Variables set up in initInNetwork()
self._tfdr = None # FDRTemporal instance
self._tfdr = None # FDRTemporal instance


#############################################################################
Expand Down Expand Up @@ -716,7 +729,6 @@ def setParameter(self, parameterName, index, parameterValue):
automatically by PyRegion's parameter set mechanism. The ones that need
special treatment are explicitly handled here.
"""

if parameterName in self._temporalArgNames:
setattr(self._tfdr, parameterName, parameterValue)

Expand All @@ -737,6 +749,7 @@ def setParameter(self, parameterName, index, parameterValue):
else:
raise Exception('Unknown parameter: ' + parameterName)


#############################################################################
#
# Commands
Expand Down Expand Up @@ -773,6 +786,54 @@ def finishLearning(self):
#############################################################################


@staticmethod
def getProtoType():
"""Return the pycapnp proto type that the class uses for serialization."""
return TMRegionProto


def writeToProto(self, proto):
"""Write state to proto object.

proto: TMRegionProto capnproto object
"""
proto.temporalImp = self.temporalImp
proto.columnCount = self.columnCount
proto.inputWidth = self.inputWidth
proto.cellsPerColumn = self.cellsPerColumn
proto.learningMode = self.learningMode
proto.inferenceMode = self.inferenceMode
proto.anomalyMode = self.anomalyMode
proto.topDownMode = self.topDownMode
proto.computePredictedActiveCellIndices = (
self.computePredictedActiveCellIndices)
proto.orColumnOutputs = self.orColumnOutputs

self._tfdr.write(proto.temporalMemory)


@classmethod
def readFromProto(cls, proto):
"""Read state from proto object.

proto: TMRegionProto capnproto object
"""
instance = cls(proto.columnCount, proto.inputWidth, proto.cellsPerColumn)

instance.temporalImp = proto.temporalImp
instance.learningMode = proto.learningMode
instance.inferenceMode = proto.inferenceMode
instance.anomalyMode = proto.anomalyMode
instance.topDownMode = proto.topDownMode
instance.computePredictedActiveCellIndices = (
proto.computePredictedActiveCellIndices)
instance.orColumnOutputs = proto.orColumnOutputs

instance._tfdr = _getTPClass(proto.temporalImp).read(proto.temporalMemory)

return instance


def __getstate__(self):
"""
Return serializable state. This function will return a version of the
Expand Down
23 changes: 22 additions & 1 deletion tests/integration/nupic/engine/network_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import unittest
import numpy

from nupic.regions.sp_region import SPRegion
from nupic.regions.record_sensor import RecordSensor
from nupic.regions.sp_region import SPRegion
from nupic.regions.tm_region import TMRegion

from network_creation_common import createAndRunNetwork

Expand Down Expand Up @@ -66,6 +67,26 @@ def testSPRegion(self):
"Row {0} not equal: {1} vs. {2}".format(i, result1, result2))


@unittest.skipUnless(
capnp, "pycapnp is not installed, skipping serialization test.")
def testTMRegion(self):
results1 = createAndRunNetwork(TMRegion, "bottomUpOut",
checkpointMidway=False,
temporalImp="tm_py")

results2 = createAndRunNetwork(TMRegion, "bottomUpOut",
checkpointMidway=True,
temporalImp="tm_py")

self.assertEqual(len(results1), len(results2))

for i in xrange(len(results1)):
result1 = list(results1[i].nonzero()[0])
result2 = list(results2[i].nonzero()[0])
self.assertEqual(result1, result2,
"Row {0} not equal: {1} vs. {2}".format(i, result1, result2))


def compareArrayResults(self, results1, results2):
self.assertEqual(len(results1), len(results2))

Expand Down