From f3f4e5cf7453487f7f716aaae20abaf23b8e7638 Mon Sep 17 00:00:00 2001 From: prwang Date: Tue, 11 Jul 2017 01:22:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=A0=B7=E4=BE=8B=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E4=B8=80=E5=8D=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + formula.py | 87 +++++++++++++++++++++++++++++++++++------------------- main.py | 33 ++++++++++++++------- search.py | 28 +++--------------- 4 files changed, 85 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index 5adb080..e1252cd 100644 --- a/.gitignore +++ b/.gitignore @@ -101,3 +101,4 @@ ENV/ .mypy_cache/ #intellij /.idea/ +in.cnf diff --git a/formula.py b/formula.py index 0370a6c..19d512c 100644 --- a/formula.py +++ b/formula.py @@ -1,4 +1,5 @@ from typing import * +from copy import copy class Clause: @@ -23,53 +24,60 @@ def undef_(self, other: int) -> None: class VarInfo: - rev: Dict[Clause, bool] = {} # role(+-) of its occurrences in each clause + rev: Dict[Clause, bool] # role(+-) of its occurrences in each clause # number of negatively and/or positively occurrences - stat: List[int] = [0, 0] + stat: List[int] n_: int def __init__(self, n: int): self.n_ = n + self.stat = [0, 0] + self.rev = {} class Formula: # the following members are managed by push & pop, used in the recursive process - cnf: Dict[int, Clause] = {} + raw : List[List[int]] + cnf: Dict[int, Clause] var: List[VarInfo] - assignment: Dict[int, Tuple[bool, # value + var_value: Dict[int, Tuple[bool, # value int, # depth - Set[int] # any edge table - ] ] = {} + Optional[Set[int]] ] ] # any edge table + model: Optional[List[int]] changes: List[Tuple[int, bool]] frame: List[int] - depth = 0 + depth: int - one: Set[Clause] # guaranteed to be size 0 after bcp exits FIXME: 递归函数里面使用全局变量小心不同层被改写! - zero: Optional[List[int]] = None #not None during the inflating process(backjumping) + one: Set[Clause] # guaranteed to be size 0 after bcp exits + zero: Optional[List[int]] #not None during the inflating process(backjumping) - def assign(self, var: int, value: bool, + def get_var(self, x) -> VarInfo: return self.var[abs(x) - 1] + + def assign(self, _var: int, value: bool, cause: Optional[Set[int]] = None) -> None: - for cl, val in self.var[var].rev: + self.var_value[_var] = (value, self.depth, + None if cause is None else copy(cause)) + for cl, val in self.get_var(_var).rev.items(): + x = cl.def_(_var * (-1 + 2 * val)) if val == value: - self.before_cl_removed(cl) + self.before_unmount(cl) del self.cnf[cl.i_] else: - x = cl.def_( val) if x == 0: self.one.remove(cl) - self.zero = list(cl.undef) + self.zero = list(cl.defined) elif x == 1: self.one.add(cl) - self.assignment[var] = (value, self.depth, cause) - self.changes.append((var, value)) + self.changes.append((_var, value)) - def unassign(self, var: int, value: bool) -> None: - del self.assignment[var] - for cl, val in reversed(self.var[var].rev): + def unassign(self, _var: int, value: bool) -> None: + for cl, val in self.get_var(_var).rev.items(): # FIXME rev是不可靠的,因为会新产生从句,导致rev变大 + #TODO 新开一个list真的记一下这些修改,这是容易办到的事情 if val == value: self.cnf[cl.i_] = cl - self.after_cl_born(cl) - else: cl.undef_(val) + self.after_mounted(cl) + cl.undef_(_var * (-1 + 2 * val)) + del self.var_value[_var] def push(self) -> None: self.depth += 1 @@ -82,16 +90,17 @@ def pop(self) -> None: self.unassign(x, y) self.depth -= 1 - def after_cl_born(self, cl: Clause) -> None: + def after_mounted(self, cl: Clause) -> None: for x in cl.undef: - self.var[abs(x)].rev[cl] = x > 0 - self.var[abs(x)].stat[x > 0] += 1 + self.get_var(x).rev[cl] = x > 0 + self.get_var(x).stat[x > 0] += 1 + assert self.get_var(x).stat[x > 0] <= 2 #TODO 这个数据调完删掉 - def before_cl_removed(self, cl: Clause) -> None: + def before_unmount(self, cl: Clause) -> None: for x in cl.undef: - self.var[abs(x)].stat[x > 0] -= 1 - assert self.var[abs(x)].stat[x > 0] >= 0 - del self.var[abs(x)].rev[cl] + self.get_var(x).stat[x > 0] -= 1 + assert self.get_var(x).stat[x > 0] >= 0 + del self.get_var(x).rev[cl] def bcp(self) -> bool: while (self.zero is not None) or len(self.one): @@ -107,9 +116,27 @@ def bcp(self) -> bool: def add_clause(self, cl) -> None: i = len(self.cnf) tp = self.cnf[i] = Clause(cl, i) - self.after_cl_born(tp) + self.after_mounted(tp) + def __init__(self, n: int, cnf1: List[List[int]]): - self.var = [VarInfo(i) for i in range(n)] + self.var = [VarInfo(1 + i) for i in range(n)] + self.raw = cnf1 + self.depth = 0 + self.cnf = {} + self.var_value = {} + self.model= None + self.changes = [] + self.frame = [] + self.one = set() + self.zero = None for cl in cnf1: self.add_clause(cl) + def validate(self) -> None: + n = len(self.var) + self.model = { (i + 1) : False for i in range(n)} + for x, (y, p1, p2) in self.var_value.items(): + self.model[x] = y + assert all([any([self.model[abs(j)] == (j > 0) for j in i]) for i in self.raw]) + + diff --git a/main.py b/main.py index 115a63e..1e08df1 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,8 @@ from typing import * +import sys import itertools as It -from formula import Formula +from search import CdclEngine + def parse() -> Tuple[int, List]: tokens = iter([]) @@ -8,28 +10,30 @@ def parse() -> Tuple[int, List]: while True: line1 = list(filter(None, input().strip().split())) if len(line1) and line1[0] != 'c': - tokens = It.chain(tokens, input().split()) + tokens = It.chain(tokens, line1) except EOFError: pass clauses = [] try: tokens = list(tokens) - assert(not (len(tokens) < 4 or tokens[0] != 'p' or tokens[1] != 'cnf')) + assert len(tokens) >= 4 and tokens[0] == 'p' and tokens[1] == 'cnf' n, m = map(int, tokens[2:4]) cc = [] for num in map(int, tokens[4:]): if num == 0: cc = list(set(cc)) if len(cc): - cc.sort(key = abs) - x = len(cc) - 1; i = 0 - while i < x: - if abs(cc[x]) == abs(cc[x + 1]): break - if i != x: clauses.append(cc) + cc.sort(key=abs) + N = len(cc) - 1; i = 0 + while i < N: + if abs(cc[i]) == abs(cc[i + 1]): break + i += 1 + if i == N: + clauses.append(cc) cc = [] else: cc.append(num) if len(cc): clauses.append(cc) - assert(len(cc) == m) + assert(len(clauses) == m) return n, clauses except (ValueError, AssertionError) as e: print("invalid cnf format") @@ -38,8 +42,17 @@ def parse() -> Tuple[int, List]: def main() -> None: n, M = parse() - fm = Formula(n, M) + fm = CdclEngine(n, M) + if fm.solve(): + print('sat') + for i in range(1, n + 1): + print("var_%d = %d"%(i, fm.model[i])) + else: + print('unsat') + if __name__ == "__main__": + sys.stdin = open('in.cnf', 'r') main() + diff --git a/search.py b/search.py index cbb33c4..05e5eec 100644 --- a/search.py +++ b/search.py @@ -1,24 +1,6 @@ from typing import * from formula import Formula -#FIXME 检查所有的abs -#FIXME 检查所有的not 注意空表是falsey,不像javascript - -''' - def SAT(Formula): - while true: - Formula1 = BCP(Formula) - Formula2 = SET_PURE_TRUE(Formula1) - if Formula2 == Formula1: - break - if Formula == true: - return true - elif Formula == false: - return false - else: - x = CHOOSE_VAR(Formula) - return Formula[true/x] or Formula[false/x] - ''' class CdclEngine(Formula): def __init__(self, n: int, cnf1: List[List[int]]): @@ -27,18 +9,17 @@ def __init__(self, n: int, cnf1: List[List[int]]): def inflate_cause(self, cause: Iterable[int]): for i in cause: # 找到低阶项或者同阶自由变量,特别的是同阶决定变量要yieldfrom 边表 i1 = abs(i) - val, dep, edg = self.assignment[i1] + val, dep, edg = self.var_value[i1] if val: i1 = -i1 # if it's been an positive assignment before, # now it's to be forced false, and vice versa if dep < self.depth or edg is None: yield (i1, edg is None) else: yield from self.inflate_cause(edg) -#TODO 这里先乱写一个能用的,调通了再改 def decide(self) -> bool: #True if already satisfied dc = (0, 0, False) for p in self.var: - if p.n_ not in self.assignment and (p.stat[0] + p.stat[1]): + if p.n_ not in self.var_value and (p.stat[0] + p.stat[1]): if not p.stat[0]: # pure self.assign(p.n_, True, None) return False @@ -62,21 +43,20 @@ def step(self) -> bool: #throws IndexError elif self.zero is None: return False assert(self.zero) newlst: List[Tuple[int, bool]] = list(self.inflate_cause(self.zero)) - # 这里会出现x or -x吗 ? self.zero = [x[0] for x in newlst] if any(x[1] for x in newlst): self.add_clause(self.zero) self.zero = None self.pop() return False + def solve(self) -> bool: try: while not self.step(): pass - + self.validate() return True except IndexError: return False - #pop两次当且仅当所有的返回项都是低阶项,TODO 反复思考这里的正确性 # 哪一步调push(), 哪一步pop() # 向下走,push(), 向上走pop(),走向sibling: pop() & push()