Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tableau merge cols #68

Merged
merged 8 commits into from
Jun 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/configuration/GlobalConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const double GlobalConfiguration::PIVOT_CHANGE_COLUMN_TOLERANCE = 0.000000001;
const unsigned GlobalConfiguration::DEGRADATION_CHECKING_FREQUENCY = 10;
const double GlobalConfiguration::DEGRADATION_THRESHOLD = 0.1;
const double GlobalConfiguration::ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD = 0.0001;
const bool GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS = false;
const double GlobalConfiguration::GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD = 0.1;
const unsigned GlobalConfiguration::MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS = 5;
const unsigned GlobalConfiguration::CONSTRAINT_VIOLATION_THRESHOLD = 20;
Expand Down Expand Up @@ -64,6 +65,7 @@ void GlobalConfiguration::print()
printf( " DEGRADATION_CHECKING_FREQUENCY: %u\n", DEGRADATION_CHECKING_FREQUENCY );
printf( " DEGRADATION_THRESHOLD: %.15lf\n", DEGRADATION_THRESHOLD );
printf( " ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD: %.15lf\n", ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD );
printf( " USE_COLUMN_MERGING_EQUATIONS: %s\n", USE_COLUMN_MERGING_EQUATIONS ? "Yes" : "No" );
printf( " GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD: %.15lf\n", GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD );
printf( " MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS: %u\n", MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS );
printf( " CONSTRAINT_VIOLATION_THRESHOLD: %u\n", CONSTRAINT_VIOLATION_THRESHOLD );
Expand Down
4 changes: 4 additions & 0 deletions src/configuration/GlobalConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class GlobalConfiguration
// to pick another element.
static const double ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD;

// If true, column-merging equations are given special treatment and cause columns in the tableau
// to be merged (instead of a new row added).
static const bool USE_COLUMN_MERGING_EQUATIONS;

// If a pivot element in a Gaussian elimination iteration is smaller than this threshold times
// the largest element in the column, the elimination engine will attempt to pick another pivot.
static const double GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD;
Expand Down
137 changes: 120 additions & 17 deletions src/engine/Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,27 +743,130 @@ void Engine::applySplit( const PiecewiseLinearCaseSplit &split )
List<Equation> equations = split.getEquations();
for ( auto &equation : equations )
{
unsigned auxVariable = _tableau->addEquation( equation );
_activeEntryStrategy->resizeHook( _tableau );

switch ( equation._type )
/*
In the general case, we just add the new equation to the tableau.
However, we also support a very common case: equations of the form
x1 = x2, which are common, e.g., with ReLUs. For these equations we
may be able to merge two columns of the tableau.
*/
unsigned x1, x2;
bool canMergeColumns =
// Only if the flag is on
GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS &&
// Only if the equation has the correct form
equation.isVariableMergingEquation( x1, x2 ) &&
// And only if the variables are not out of bounds
( !_tableau->isBasic( x1 ) ||
!_tableau->basicOutOfBounds( _tableau->variableToIndex( x1 ) ) )
&&
( !_tableau->isBasic( x2 ) ||
!_tableau->basicOutOfBounds( _tableau->variableToIndex( x2 ) ) );

if ( canMergeColumns )
{
case Equation::GE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;
/*
Special case: x1 and x2 need to be merged.
First, we need to ensure they are both non-basic.
*/
unsigned n = _tableau->getN();
unsigned m = _tableau->getM();

if ( _tableau->isBasic( x1 ) )
{
TableauRow x1Row( n - m );
_tableau->getTableauRow( _tableau->variableToIndex( x1 ), &x1Row );

case Equation::LE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
break;
bool found = false;
unsigned nonBasic;
for ( unsigned i = 0; i < n - m; ++i )
{
if ( !FloatUtils::isZero( x1Row._row[i]._coefficient ) && ( x1Row._row[i]._var != x2 ) )
{
found = true;
nonBasic = x1Row._row[i]._var;
break;
}
}

case Equation::EQ:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;
if ( !found )
throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED,
"Could not find a variable to pivot with" );

default:
ASSERT( false );
break;
_tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) );
_tableau->setLeavingVariableIndex( _tableau->variableToIndex( x1 ) );

// Make sure the change column and pivot row are up-to-date - strategies
// such as projected steepest edge need these for their internal updates.
_tableau->computeChangeColumn();
_tableau->computePivotRow();

_activeEntryStrategy->prePivotHook( _tableau, false );
_tableau->performDegeneratePivot();
_activeEntryStrategy->prePivotHook( _tableau, false );
}

if ( _tableau->isBasic( x2 ) )
{
TableauRow x2Row( n - m );
_tableau->getTableauRow( _tableau->variableToIndex( x2 ), &x2Row );

bool found = false;
unsigned nonBasic;
for ( unsigned i = 0; i < n - m; ++i )
{
if ( !FloatUtils::isZero( x2Row._row[i]._coefficient ) && ( x2Row._row[i]._var != x1 ) )
{
found = true;
nonBasic = x2Row._row[i]._var;
break;
}
}

if ( !found )
throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED,
"Could not find a variable to pivot with" );

_tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) );
_tableau->setLeavingVariableIndex( _tableau->variableToIndex( x2 ) );

// Make sure the change column and pivot row are up-to-date - strategies
// such as projected steepest edge need these for their internal updates.
_tableau->computeChangeColumn();
_tableau->computePivotRow();

_activeEntryStrategy->prePivotHook( _tableau, false );
_tableau->performDegeneratePivot();
_activeEntryStrategy->prePivotHook( _tableau, false );
}

// Both variables are now non-basic, so we can merge their columns
_tableau->mergeColumns( x1, x2 );
}
else
{
// General case: add a new equation to the tableau
unsigned auxVariable = _tableau->addEquation( equation );
_activeEntryStrategy->resizeHook( _tableau );

switch ( equation._type )
{
case Equation::GE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;

case Equation::LE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
break;

case Equation::EQ:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;

default:
ASSERT( false );
break;
}
}
}

Expand Down
24 changes: 24 additions & 0 deletions src/engine/Equation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ void Equation::dump() const
printf( "%.2lf\n", _scalar );
}

bool Equation::isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const
{
if ( _addends.size() != 2 )
return false;

if ( !FloatUtils::isZero( _scalar ) )
return false;

double coefficientOne = _addends.front()._coefficient;
double coefficientTwo = _addends.back()._coefficient;

if ( FloatUtils::isZero( coefficientOne ) || FloatUtils::isZero( coefficientTwo ) )
return false;

if ( FloatUtils::areEqual( coefficientOne, -coefficientTwo ) )
{
x1 = _addends.front()._variable;
x2 = _addends.back()._variable;
return true;
}

return false;
}

//
// Local Variables:
// compile-command: "make -C ../.. "
Expand Down
7 changes: 7 additions & 0 deletions src/engine/Equation.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class Equation
*/
void updateVariableIndex( unsigned oldVar, unsigned newVar );

/*
Return true iff the variable is a "variable merging equation",
i.e. an equation of the form x = y. If true is returned, x1 and
x2 are the merged variables.
*/
bool isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const;

List<Addend> _addends;
double _scalar;
EquationType _type;
Expand Down
1 change: 1 addition & 0 deletions src/engine/ITableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class ITableau
virtual Equation *getBasisEquation( unsigned row ) const = 0;
virtual double *getInverseBasisMatrix() const = 0;
virtual void refreshBasisFactorization() = 0;
virtual void mergeColumns( unsigned x1, unsigned x2 ) = 0;
};

#endif // __ITableau_h__
Expand Down
1 change: 1 addition & 0 deletions src/engine/ReluplexError.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ReluplexError : public Error
CANNOT_RESTORE_TABLEAU = 12,
FAILURE_TO_ADD_NEW_EQUATION = 13,
RESTORATION_FAILED_TO_REFACTORIZE_BASIS = 14,
ENGINE_APPLY_SPLIT_FAILED = 15,

DEBUGGING_ERROR = 999,
};
Expand Down
48 changes: 46 additions & 2 deletions src/engine/Tableau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,17 @@ const double *Tableau::getUpperBounds() const

double Tableau::getValue( unsigned variable )
{
/*
If this variable has been merged into another,
we need to be reading the other variable's value
*/
if ( _mergedVariables.exists( variable ) )
variable = _mergedVariables[variable];

// The values of non-basics can be extracted even if the
// assignment is invalid
if ( !_basicVariables.exists( variable ) )
{
// The values of non-basics can be extracted even if the
// assignment is invalid
unsigned index = _variableToIndex[variable];
return _nonBasicAssignment[index];
}
Expand Down Expand Up @@ -1112,6 +1119,9 @@ void Tableau::storeState( TableauState &state ) const

// Store the _boundsValid indicator
state._boundsValid = _boundsValid;

// Store the merged variables
state._mergedVariables = _mergedVariables;
}

void Tableau::restoreState( const TableauState &state )
Expand Down Expand Up @@ -1149,6 +1159,9 @@ void Tableau::restoreState( const TableauState &state )
// Restore the _boundsValid indicator
_boundsValid = state._boundsValid;

// Restore the merged varaibles
_mergedVariables = state._mergedVariables;

computeAssignment();
_costFunctionManager->initialize();
computeCostFunction();
Expand Down Expand Up @@ -1911,6 +1924,8 @@ void Tableau::registerCostFunctionManager( ICostFunctionManager *costFunctionMan
const double *Tableau::getColumnOfBasis( unsigned column ) const
{
ASSERT( column < _m );
ASSERT( !_mergedVariables.exists( _basicIndexToVariable[column] ) );

unsigned variable = _basicIndexToVariable[column];
return _A + ( variable * _m );
}
Expand All @@ -1920,6 +1935,35 @@ void Tableau::refreshBasisFactorization()
_basisFactorization->obtainFreshBasis();
}

void Tableau::mergeColumns( unsigned x1, unsigned x2 )
{
ASSERT( !isBasic( x1 ) );
ASSERT( !isBasic( x2 ) );

/*
If x2 has tighter bounds than x1, adjust the bounds
for x1.
*/
if ( FloatUtils::lt( _upperBounds[x2], _upperBounds[x1] ) )
tightenUpperBound( x1, _upperBounds[x2] );
if ( FloatUtils::gt( _lowerBounds[x2], _lowerBounds[x1] ) )
tightenLowerBound( x1, _lowerBounds[x2] );

/*
Merge column x2 of the constraint matrix into x1
and zero-out column x2
*/
for ( unsigned row = 0; row < _m; ++row )
{
_A[(x1 * _m) + row] += _A[(x2 * _m) + row];
_A[(x2 * _m) + row] = 0.0;
}
_mergedVariables[x2] = x1;

computeAssignment();
computeCostFunction();
}

//
// Local Variables:
// compile-command: "make -C ../.. "
Expand Down
13 changes: 13 additions & 0 deletions src/engine/Tableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,12 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle
*/
void refreshBasisFactorization();

/*
Merge two columns of the constraint matrix and re-initialize
the tableau.
*/
void mergeColumns( unsigned x1, unsigned x2 );

private:
/*
Variable watchers
Expand Down Expand Up @@ -547,6 +553,13 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle
*/
ICostFunctionManager *_costFunctionManager;

/*
_mergedVariables[x] = y means that x = y, and that
variable x has been merged into variable y. So, when
extracting a solution for x, we should read the value of y.
*/
Map<unsigned, unsigned> _mergedVariables;

/*
Free all allocated memory.
*/
Expand Down
8 changes: 8 additions & 0 deletions src/engine/TableauState.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "IBasisFactorization.h"
#include "ITableau.h"
#include "Map.h"
#include "Set.h"

class TableauState
Expand Down Expand Up @@ -104,6 +105,13 @@ class TableauState
Indicator whether the bounds are valid
*/
bool _boundsValid;

/*
_mergedVariables[x] = y means that x = y, and that
variable x has been merged into variable y. So, when
extracting a solution for x, we should read the value of y.
*/
Map<unsigned, unsigned> _mergedVariables;
};

#endif // __TableauState_h__
Expand Down
4 changes: 4 additions & 0 deletions src/engine/tests/MockTableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ class MockTableau : public ITableau
void refreshBasisFactorization()
{
}

void mergeColumns( unsigned /* x1 */, unsigned /* x2 */ )
{
}
};

#endif // __MockTableau_h__
Expand Down