Skip to content

Commit

Permalink
DnC solver (#113)
Browse files Browse the repository at this point in the history
* DnC solver

* modify readme

* add some comment to the main method

* add comments
  • Loading branch information
wu-haoze authored and guykatzz committed Feb 5, 2019
1 parent c3f33a4 commit 085138e
Show file tree
Hide file tree
Showing 8 changed files with 978 additions and 1 deletion.
84 changes: 84 additions & 0 deletions maraboupy/AcasNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
'''
/* ******************* */
/*! \file AcasNet.py
** \verbatim
** Top contributors (to current version):
** Haoze Wu
** This file is part of the Marabou project.
** Copyright (c) 2017-2019 by the authors listed in the file AUTHORS
** in the top-level source directory) and their institutional affiliations.
** All rights reserved. See the file COPYING in the top-level source
** directory for licensing information.\endverbatim
**
** \brief Parser class that uses native Marabou's parser
**
** [[ Add lengthier description here ]]
**/
'''

from . import MarabouCore
import numpy as np

class AcasNet:
"""
Class representing AcasXU Marabou network
"""
def __init__(self, network_path, property_path="", lbs=None, ubs=None):
"""
Constructs a MarabouNetwork object and calls function to initialize
"""
self.network_path = network_path
self.getMarabouQuery(property_path)
if not(lbs is None and ubs is None):
assert(len(lbs) == len(ubs))
assert(len(lbs) == self.ipq.getNumInputVariables())
for i in range(len(lbs)):
self.setInputLowerBound(i, lbs[i])
self.setInputUpperBound(i, ubs[i])



def getMarabouQuery(self, property_path=""):
self.ipq = MarabouCore.InputQuery()
MarabouCore.createInputQuery(self.ipq, self.network_path, property_path)

def setInputLowerBound(self, input_ind, scalar):
assert(input_ind < self.ipq.getNumInputVariables())
variable = self.ipq.inputVariableByIndex(input_ind)
if self.ipq.getLowerBound(variable) < scalar:
self.ipq.setLowerBound(variable, scalar)

def setInputUpperBound(self, input_ind, scalar):
assert(input_ind < self.ipq.getNumInputVariables())
variable = self.ipq.inputVariableByIndex(input_ind)
if self.ipq.getUpperBound(variable) > scalar:
self.ipq.setUpperBound(variable, scalar)

def getInputRanges(self):
inputMins = []
inputMaxs = []
for input_ind in range(self.ipq.getNumInputVariables()):
variable = self.ipq.inputVariableByIndex(input_ind)
inputMins.append(self.ipq.getLowerBound(variable))
inputMaxs.append(self.ipq.getUpperBound(variable))
return np.array(inputMins), np.array(inputMaxs)


def solve(self, filename="", timeout=0):
"""
Function to solve query represented by this network
Arguments:
filename: (string) path to redirect output to
Returns:
vals: (dict: int->float) empty if UNSAT, else the
satisfying assignment to the input and output variables
stats: (Statistics) the Statistics object as defined in Marabou
"""
vals, stats = MarabouCore.solve(self.ipq, filename, timeout)
assignment = []
if len(vals) > 0:
for i in range(self.ipq.getNumInputVariables()):
assignment.append("input {} = {}".format(i, vals[self.ipq.inputVariableByIndex(i)]))
for i in range(self.ipq.getNumOutputVariables()):
assignment.append("Output {} = {}".format(i, vals[self.ipq.outputVariableByIndex(i)]))
return [assignment, stats]
89 changes: 89 additions & 0 deletions maraboupy/DnC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
'''
/* ******************* */
/*! \file DnCParallelSolver.py
** \verbatim
** Top contributors (to current version):
** Haoze Wu
** This file is part of the Marabou project.
** Copyright (c) 2017-2019 by the authors listed in the file AUTHORS
** in the top-level source directory) and their institutional affiliations.
** All rights reserved. See the file COPYING in the top-level source
** directory for licensing information.\endverbatim
**
** \brief Main method that calls the divide-and-conquer solver
**
** [[ Add lengthier description here ]]
**/
'''

"""
"""
from maraboupy import DnCSolver
from maraboupy import Options

import numpy as np
from multiprocessing import Process, Pipe
import os

def main():
"""
The main method
Checking a property (property.txt) on a network (network.nnet) and
save the results to a summary file (summary.txt)
"python3 DnC.py -n network -q property.txt --summary-file=summary.txt"
Call "python3 DnC.py --help" to see other options
"""

options, args = Options.create_parser().parse_args()

try:
property_path = options.query
except:
assert(os.path.isfile(property_path))
exit()
try:
network_name = options.network_name
assert(os.path.isfile(network_name + ".pb"))
assert(os.path.isfile(network_name + ".nnet"))
except:
print("Fail to import network!")
exit()


# Get arguments from input
(num_workers, input_name,
initial_splits, online_split,
init_to, to_factor, strategy,
seed, log_file, summary_file) = Options.get_constructor_arguments(options)

solver = DnCSolver.DnCSolver(network_name, property_path, num_workers,
initial_splits, online_split, init_to, to_factor,
strategy, input_name, seed, log_file)


try:
# Initial split of the input region
parent_conn, child_conn = Pipe()
p = Process(target=getSubProblems, args=(solver, child_conn))
p.start()
sub_queries = parent_conn.recv()
p.join()

# Solve the created subqueries
solver.solve(sub_queries)
except KeyboardInterrupt:
solver.write_summary_file(summary_file, True)
else:
solver.write_summary_file(summary_file, False)
return

def getSubProblems(solver, conn):
sub_queries = solver.initial_split()
conn.send(sub_queries)
conn.close()


if __name__ == "__main__":
main()
Loading

0 comments on commit 085138e

Please sign in to comment.