Skip to content

Commit

Permalink
Fix issues upgrading state leading to SIGSEGV
Browse files Browse the repository at this point in the history
  • Loading branch information
tveasey committed Jul 2, 2018
1 parent e16816e commit a341300
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 21 deletions.
5 changes: 4 additions & 1 deletion include/core/CStateMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <atomic>
#include <cstddef>
#include <list>
#include <map>
#include <vector>

namespace ml {
Expand Down Expand Up @@ -67,6 +68,7 @@ class CORE_EXPORT CStateMachine {
using TSizeVec = std::vector<std::size_t>;
using TSizeVecVec = std::vector<TSizeVec>;
using TStrVec = std::vector<std::string>;
using TSizeSizeMap = std::map<std::size_t, std::size_t>;

public:
//! Set the number of machines we expect the program to use.
Expand All @@ -85,7 +87,8 @@ class CORE_EXPORT CStateMachine {
//! \name Persistence
//@{
//! Initialize by reading state from \p traverser.
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser,
const TSizeSizeMap& mapping = TSizeSizeMap());

//! Persist state by passing information to the supplied inserter.
void acceptPersistInserter(CStatePersistInserter& inserter) const;
Expand Down
22 changes: 16 additions & 6 deletions lib/core/CStateMachine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace core {
namespace {

// CStateMachine
const std::string MACHINE_TAG("a");
//const std::string MACHINE_TAG("a"); No longer used
const std::string STATE_TAG("b");

// CStateMachine::SMachine
Expand Down Expand Up @@ -88,17 +88,26 @@ CStateMachine CStateMachine::create(const TStrVec& alphabet,
return result;
}

bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser,
const TSizeSizeMap& mapping) {
do {
const std::string& name = traverser.name();
RESTORE_BUILT_IN(MACHINE_TAG, m_Machine)
RESTORE_BUILT_IN(STATE_TAG, m_State)
} while (traverser.next());
if (mapping.size() > 0) {
auto mapped = mapping.find(m_State);
if (mapped != mapping.end()) {
m_State = mapped->second;
} else {
LOG_ERROR(<< "Bad mapping '" << core::CContainerPrinter::print(mapping)
<< "' state = " << m_State);
return false;
}
}
return true;
}

void CStateMachine::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
inserter.insertValue(MACHINE_TAG, m_Machine);
inserter.insertValue(STATE_TAG, m_State);
}

Expand Down Expand Up @@ -201,14 +210,15 @@ void CStateMachine::CMachineDeque::capacity(std::size_t capacity) {
m_Capacity = capacity;
}

const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos) const {
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos_) const {
std::size_t pos{pos_};
for (const auto& machines : m_Machines) {
if (pos < machines.size()) {
return machines[pos];
}
pos -= machines.size();
}
LOG_ABORT(<< "Invalid index '" << pos << "'");
LOG_ABORT(<< "Invalid index '" << pos_ << "'");
}

std::size_t CStateMachine::CMachineDeque::size() const {
Expand Down
7 changes: 4 additions & 3 deletions lib/core/unittest/CStateMachineTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ void CStateMachineTest::testPersist() {
core::CRapidXmlStateRestoreTraverser traverser(parser);

core::CStateMachine restored = core::CStateMachine::create(
machine[1].s_Alphabet, machine[1].s_States, machine[1].s_TransitionFunction,
machine[0].s_Alphabet, machine[0].s_States, machine[0].s_TransitionFunction,
0); // initial state
traverser.traverseSubLevel(
boost::bind(&core::CStateMachine::acceptRestoreTraverser, &restored, _1));
traverser.traverseSubLevel([&restored](core::CStateRestoreTraverser& traverser_) {
return restored.acceptRestoreTraverser(traverser_);
});

CPPUNIT_ASSERT_EQUAL(original.checksum(), restored.checksum());
std::string newXml;
Expand Down
35 changes: 24 additions & 11 deletions lib/maths/CTimeSeriesDecompositionDetail.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

#include <algorithm>
#include <cmath>
#include <map>
#include <numeric>
#include <string>
#include <vector>
Expand All @@ -61,6 +62,7 @@ using TBoolVec = std::vector<bool>;
using TDoubleVec = std::vector<double>;
using TSizeVec = std::vector<std::size_t>;
using TSizeVecVec = std::vector<TSizeVec>;
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
using TStrVec = std::vector<std::string>;
using TTimeVec = std::vector<core_t::TTime>;
using TTimeTimePr = std::pair<core_t::TTime, core_t::TTime>;
Expand Down Expand Up @@ -319,7 +321,7 @@ const std::string LAST_UPDATE_OLD_TAG{"j"};

//////////////////////// Upgrade to Version 6.3 ////////////////////////

const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3{48.0};
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3{48.0};

bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
CTrendComponent& trend,
Expand All @@ -342,7 +344,7 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,

// Generate some samples from the old trend model.

double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3 *
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3 *
static_cast<double>(bucketLength) / static_cast<double>(4 * WEEK)};

CPRNG::CXorOShiro128Plus rng;
Expand All @@ -355,6 +357,8 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
return true;
}

const TSizeSizeMap SC_STATES_UPGRADING_TO_VERSION_6_3{{0, 0}, {1, 1}, {2, 1}, {3, 2}, {4, 3}};

////////////////////////////////////////////////////////////////////////

// Constants
Expand Down Expand Up @@ -490,8 +494,9 @@ bool CTimeSeriesDecompositionDetail::CPeriodicityTest::acceptRestoreTraverser(
do {
const std::string& name{traverser.name()};
RESTORE(PERIODICITY_TEST_MACHINE_6_3_TAG,
traverser.traverseSubLevel(boost::bind(
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
return m_Machine.acceptRestoreTraverser(traverser_);
}))
RESTORE_SETUP_TEARDOWN(
SHORT_WINDOW_6_3_TAG, m_Windows[E_Short].reset(this->newWindow(E_Short)),
m_Windows[E_Short] && traverser.traverseSubLevel(boost::bind(
Expand Down Expand Up @@ -792,8 +797,9 @@ bool CTimeSeriesDecompositionDetail::CCalendarTest::acceptRestoreTraverser(core:
do {
const std::string& name{traverser.name()};
RESTORE(CALENDAR_TEST_MACHINE_6_3_TAG,
traverser.traverseSubLevel(boost::bind(
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
return m_Machine.acceptRestoreTraverser(traverser_);
}))
RESTORE_BUILT_IN(LAST_MONTH_6_3_TAG, m_LastMonth);
RESTORE_SETUP_TEARDOWN(
CALENDAR_TEST_6_3_TAG,
Expand Down Expand Up @@ -999,8 +1005,9 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
while (traverser.next()) {
const std::string& name{traverser.name()};
RESTORE(COMPONENTS_MACHINE_6_3_TAG,
traverser.traverseSubLevel(boost::bind(
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
return m_Machine.acceptRestoreTraverser(traverser_);
}))
RESTORE_BUILT_IN(DECAY_RATE_6_3_TAG, m_DecayRate);
RESTORE(GAIN_CONTROLLER_6_3_TAG,
traverser.traverseSubLevel(boost::bind(&CGainController::acceptRestoreTraverser,
Expand Down Expand Up @@ -1035,8 +1042,10 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
do {
const std::string& name{traverser.name()};
RESTORE(COMPONENTS_MACHINE_OLD_TAG,
traverser.traverseSubLevel(boost::bind(
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
return m_Machine.acceptRestoreTraverser(
traverser_, SC_STATES_UPGRADING_TO_VERSION_6_3);
}))
RESTORE_SETUP_TEARDOWN(TREND_OLD_TAG,
/**/,
traverser.traverseSubLevel(boost::bind(
Expand All @@ -1057,7 +1066,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
/**/)
} while (traverser.next());

m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3);
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3);
}
return true;
}
Expand Down Expand Up @@ -1951,13 +1960,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CSeasonal::acceptRestoreTraver
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
m_Components.emplace_back(decayRate, bucketLength, traverser))
}
m_PredictionErrors.resize(m_Components.size());
} else {
// There is no version string this is historic state.
do {
const std::string& name{traverser.name()};
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
m_Components.emplace_back(decayRate, bucketLength, traverser))
} while (traverser.next());
m_PredictionErrors.resize(m_Components.size());
}
return true;
}
Expand Down Expand Up @@ -2253,13 +2264,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CCalendar::acceptRestoreTraver
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
m_Components.emplace_back(decayRate, bucketLength, traverser))
}
m_PredictionErrors.resize(m_Components.size());
} else {
// There is no version string this is historic state.
do {
const std::string& name{traverser.name()};
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
m_Components.emplace_back(decayRate, bucketLength, traverser))
} while (traverser.next());
m_PredictionErrors.resize(m_Components.size());
}
return true;
}
Expand Down
16 changes: 16 additions & 0 deletions lib/maths/unittest/CTimeSeriesDecompositionTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,8 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
// Check we can validly upgrade existing state.

using TStrVec = std::vector<std::string>;
using TDouble3Vec = core::CSmallVector<double, 3>;

auto load = [](const std::string& name, std::string& result) {
std::ifstream file;
file.open(name);
Expand Down Expand Up @@ -2126,6 +2128,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
CPPUNIT_ASSERT_DOUBLES_EQUAL(expectedScale.second, scale.second,
0.005 * std::max(expectedScale.second, 0.4));
}

// Check some basic operations on the upgraded model.
decomposition.forecast(60480000, 60480000 + WEEK, HALF_HOUR, 90.0, 1.0,
[](core_t::TTime, const TDouble3Vec&) {});
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
decomposition.addPoint(time, 10.0);
}
}

LOG_DEBUG(<< "*** Trend and Seasonal Components ***");
Expand Down Expand Up @@ -2201,6 +2210,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
LOG_DEBUG(<< "Mean scale error = " << maths::CBasicStatistics::mean(meanScaleError));
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanValueError) < 0.06);
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanScaleError) < 0.07);

// Check some basic operations on the upgraded model.
decomposition.forecast(10366200, 10366200 + WEEK, HALF_HOUR, 90.0, 1.0,
[](core_t::TTime, const TDouble3Vec&) {});
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
decomposition.addPoint(time, 10.0);
}
}
}

Expand Down

0 comments on commit a341300

Please sign in to comment.