Skip to content

Commit

Permalink
Max Constraint Fix implemented by Duligur (#96)
Browse files Browse the repository at this point in the history
* Copied some files

* Simple check for obsolete if phase is fixed

* Better obsolete()
  • Loading branch information
ShantanuThakoor authored Oct 4, 2018
1 parent 0b0c927 commit 2881fba
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 182 deletions.
37 changes: 17 additions & 20 deletions maraboupy/MarabouNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self):
Constructs a MarabouNetwork object and calls function to initialize
"""
self.clear()

def clear(self):
"""
Reset values to represent empty network
Expand All @@ -20,29 +20,30 @@ def clear(self):
self.equList = []
self.reluList = []
self.maxList = []
self.varsParticipatingInConstraints = set()
self.lowerBounds = dict()
self.upperBounds = dict()
self.inputVars = []
self.outputVars = np.array([])

def getNewVariable(self):
"""
Function to request allocation of new variable
Returns:
varnum: (int) representing new variable
"""
self.numVars += 1
return self.numVars - 1

def addEquation(self, x):
"""
Function to add new equation to the network
Arguments:
x: (MarabouUtils.Equation) representing new equation
"""
self.equList += [x]

def setLowerBound(self, x, v):
"""
Function to set lower bound for variable
Expand All @@ -51,7 +52,7 @@ def setLowerBound(self, x, v):
v: (float) value representing lower bound
"""
self.lowerBounds[x]=v

def setUpperBound(self, x, v):
"""
Function to set upper bound for variable
Expand All @@ -69,7 +70,9 @@ def addRelu(self, v1, v2):
v2: (int) variable representing output of Relu
"""
self.reluList += [(v1, v2)]

self.varsParticipatingInConstraints.add(v1)
self.varsParticipatingInConstraints.add(v2)

def addMaxConstraint(self, elements, v):
"""
Function to add a new Max constraint
Expand All @@ -78,6 +81,9 @@ def addMaxConstraint(self, elements, v):
v: (int) variable representing output of max constraint
"""
self.maxList += [(elements, v)]
self.varsParticipatingInConstraints.add(v)
for i in elements:
self.varsParticipatingInConstraints.add(i)

def lowerBoundExists(self, x):
"""
Expand All @@ -102,16 +108,7 @@ def participatesInPLConstraint(self, x):
x: (int) variable to check
"""
# ReLUs
if self.reluList:
fs, bs = zip(*self.reluList)
if x in fs or x in bs:
return True

# Max constraints
for elems, var in self.maxList:
if x in elems or x==var:
return True
return False
return x in self.varsParticipatingInConstraints

def getMarabouQuery(self):
"""
Expand All @@ -121,7 +118,7 @@ def getMarabouQuery(self):
"""
ipq = MarabouCore.InputQuery()
ipq.setNumberOfVariables(self.numVars)

for e in self.equList:
eq = MarabouCore.Equation(e.EquationType)
for (c, v) in e.addendList:
Expand All @@ -133,7 +130,7 @@ def getMarabouQuery(self):
for r in self.reluList:
assert r[1] < self.numVars and r[0] < self.numVars
MarabouCore.addReluConstraint(ipq, r[0], r[1])

for m in self.maxList:
assert m[1] < self.numVars
for e in m[0]:
Expand Down Expand Up @@ -190,7 +187,7 @@ def evaluateWithMarabou(self, inputValues, filename="evaluateWithMarabou.log"):
print("Evaluating with Marabou\n")
inputVars = self.inputVars # list of numpy arrays
outputVars = self.outputVars

inputDict = dict()
inputVarList = np.concatenate(inputVars, axis=-1).ravel()
inputValList = np.concatenate(inputValues).ravel()
Expand Down
Loading

0 comments on commit 2881fba

Please sign in to comment.