From c01a95bf128baf07f4ec1c9be4b260a91c4d6ca1 Mon Sep 17 00:00:00 2001 From: Marek Otahal Date: Thu, 11 Jan 2018 14:22:38 +0100 Subject: [PATCH] SP: add check for unstable params in constructor on each initialization, SP is tested to produce stable output. --- src/nupic/algorithms/SpatialPooler.cpp | 14 ++++++++++++++ src/test/unit/algorithms/SpatialPoolerTest.cpp | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/nupic/algorithms/SpatialPooler.cpp b/src/nupic/algorithms/SpatialPooler.cpp index bd8b366173..39f29992c3 100644 --- a/src/nupic/algorithms/SpatialPooler.cpp +++ b/src/nupic/algorithms/SpatialPooler.cpp @@ -488,6 +488,17 @@ const vector& SpatialPooler::getBoostedOverlaps() const return boostedOverlaps_; } +/** helper method that checks SP output is stable for given configuration */ +bool checkUnstableParams_(SpatialPooler &sp) { + vector input(sp.getNumInputs(), 1); //auto size of input + vector out1(sp.getNumColumns(), 0); + vector out2(sp.getNumColumns(), 0); + sp.compute(input.data(), true, out1.data()); + sp.compute(input.data(), true, out2.data()); + //TODO should we add SP.reset() and call it here? + return std::equal(std::begin(out1), std::end(out1), std::begin(out2)); //compare all, element wise +} + void SpatialPooler::initialize(vector inputDimensions, vector columnDimensions, UInt potentialRadius, @@ -593,6 +604,9 @@ void SpatialPooler::initialize(vector inputDimensions, printParameters(); std::cout << "CPP SP seed = " << seed << std::endl; } + + //check for reasonable params + NTA_CHECK(checkUnstableParams_(*this)); //TODO the assert runs only at debug builds, make mandatory? } void SpatialPooler::compute(UInt inputArray[], bool learn, diff --git a/src/test/unit/algorithms/SpatialPoolerTest.cpp b/src/test/unit/algorithms/SpatialPoolerTest.cpp index c873650186..a380d1b866 100644 --- a/src/test/unit/algorithms/SpatialPoolerTest.cpp +++ b/src/test/unit/algorithms/SpatialPoolerTest.cpp @@ -2244,7 +2244,7 @@ namespace { vector out2(sp.getNumColumns(), 0); sp.compute(input.data(), true, out1.data()); sp.compute(input.data(), true, out2.data()); - EXPECT_EQ(out1, out2); + EXPECT_EQ(out1, out2); //not necessary with the check in SP initialize(), but keep here as example }