-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* DnC solver * modify readme * add some comment to the main method * add comments
- Loading branch information
Showing
8 changed files
with
978 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
''' | ||
/* ******************* */ | ||
/*! \file AcasNet.py | ||
** \verbatim | ||
** Top contributors (to current version): | ||
** Haoze Wu | ||
** This file is part of the Marabou project. | ||
** Copyright (c) 2017-2019 by the authors listed in the file AUTHORS | ||
** in the top-level source directory) and their institutional affiliations. | ||
** All rights reserved. See the file COPYING in the top-level source | ||
** directory for licensing information.\endverbatim | ||
** | ||
** \brief Parser class that uses native Marabou's parser | ||
** | ||
** [[ Add lengthier description here ]] | ||
**/ | ||
''' | ||
|
||
from . import MarabouCore | ||
import numpy as np | ||
|
||
class AcasNet: | ||
""" | ||
Class representing AcasXU Marabou network | ||
""" | ||
def __init__(self, network_path, property_path="", lbs=None, ubs=None): | ||
""" | ||
Constructs a MarabouNetwork object and calls function to initialize | ||
""" | ||
self.network_path = network_path | ||
self.getMarabouQuery(property_path) | ||
if not(lbs is None and ubs is None): | ||
assert(len(lbs) == len(ubs)) | ||
assert(len(lbs) == self.ipq.getNumInputVariables()) | ||
for i in range(len(lbs)): | ||
self.setInputLowerBound(i, lbs[i]) | ||
self.setInputUpperBound(i, ubs[i]) | ||
|
||
|
||
|
||
def getMarabouQuery(self, property_path=""): | ||
self.ipq = MarabouCore.InputQuery() | ||
MarabouCore.createInputQuery(self.ipq, self.network_path, property_path) | ||
|
||
def setInputLowerBound(self, input_ind, scalar): | ||
assert(input_ind < self.ipq.getNumInputVariables()) | ||
variable = self.ipq.inputVariableByIndex(input_ind) | ||
if self.ipq.getLowerBound(variable) < scalar: | ||
self.ipq.setLowerBound(variable, scalar) | ||
|
||
def setInputUpperBound(self, input_ind, scalar): | ||
assert(input_ind < self.ipq.getNumInputVariables()) | ||
variable = self.ipq.inputVariableByIndex(input_ind) | ||
if self.ipq.getUpperBound(variable) > scalar: | ||
self.ipq.setUpperBound(variable, scalar) | ||
|
||
def getInputRanges(self): | ||
inputMins = [] | ||
inputMaxs = [] | ||
for input_ind in range(self.ipq.getNumInputVariables()): | ||
variable = self.ipq.inputVariableByIndex(input_ind) | ||
inputMins.append(self.ipq.getLowerBound(variable)) | ||
inputMaxs.append(self.ipq.getUpperBound(variable)) | ||
return np.array(inputMins), np.array(inputMaxs) | ||
|
||
|
||
def solve(self, filename="", timeout=0): | ||
""" | ||
Function to solve query represented by this network | ||
Arguments: | ||
filename: (string) path to redirect output to | ||
Returns: | ||
vals: (dict: int->float) empty if UNSAT, else the | ||
satisfying assignment to the input and output variables | ||
stats: (Statistics) the Statistics object as defined in Marabou | ||
""" | ||
vals, stats = MarabouCore.solve(self.ipq, filename, timeout) | ||
assignment = [] | ||
if len(vals) > 0: | ||
for i in range(self.ipq.getNumInputVariables()): | ||
assignment.append("input {} = {}".format(i, vals[self.ipq.inputVariableByIndex(i)])) | ||
for i in range(self.ipq.getNumOutputVariables()): | ||
assignment.append("Output {} = {}".format(i, vals[self.ipq.outputVariableByIndex(i)])) | ||
return [assignment, stats] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
''' | ||
/* ******************* */ | ||
/*! \file DnCParallelSolver.py | ||
** \verbatim | ||
** Top contributors (to current version): | ||
** Haoze Wu | ||
** This file is part of the Marabou project. | ||
** Copyright (c) 2017-2019 by the authors listed in the file AUTHORS | ||
** in the top-level source directory) and their institutional affiliations. | ||
** All rights reserved. See the file COPYING in the top-level source | ||
** directory for licensing information.\endverbatim | ||
** | ||
** \brief Main method that calls the divide-and-conquer solver | ||
** | ||
** [[ Add lengthier description here ]] | ||
**/ | ||
''' | ||
|
||
""" | ||
""" | ||
from maraboupy import DnCSolver | ||
from maraboupy import Options | ||
|
||
import numpy as np | ||
from multiprocessing import Process, Pipe | ||
import os | ||
|
||
def main(): | ||
""" | ||
The main method | ||
Checking a property (property.txt) on a network (network.nnet) and | ||
save the results to a summary file (summary.txt) | ||
"python3 DnC.py -n network -q property.txt --summary-file=summary.txt" | ||
Call "python3 DnC.py --help" to see other options | ||
""" | ||
|
||
options, args = Options.create_parser().parse_args() | ||
|
||
try: | ||
property_path = options.query | ||
except: | ||
assert(os.path.isfile(property_path)) | ||
exit() | ||
try: | ||
network_name = options.network_name | ||
assert(os.path.isfile(network_name + ".pb")) | ||
assert(os.path.isfile(network_name + ".nnet")) | ||
except: | ||
print("Fail to import network!") | ||
exit() | ||
|
||
|
||
# Get arguments from input | ||
(num_workers, input_name, | ||
initial_splits, online_split, | ||
init_to, to_factor, strategy, | ||
seed, log_file, summary_file) = Options.get_constructor_arguments(options) | ||
|
||
solver = DnCSolver.DnCSolver(network_name, property_path, num_workers, | ||
initial_splits, online_split, init_to, to_factor, | ||
strategy, input_name, seed, log_file) | ||
|
||
|
||
try: | ||
# Initial split of the input region | ||
parent_conn, child_conn = Pipe() | ||
p = Process(target=getSubProblems, args=(solver, child_conn)) | ||
p.start() | ||
sub_queries = parent_conn.recv() | ||
p.join() | ||
|
||
# Solve the created subqueries | ||
solver.solve(sub_queries) | ||
except KeyboardInterrupt: | ||
solver.write_summary_file(summary_file, True) | ||
else: | ||
solver.write_summary_file(summary_file, False) | ||
return | ||
|
||
def getSubProblems(solver, conn): | ||
sub_queries = solver.initial_split() | ||
conn.send(sub_queries) | ||
conn.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.