(*  File:   ~~/src/Tools/Argo/argo_cdcl.ML

Propositional satisfiability solver in the style of conflict-driven clause learning (CDCL).
*)
signature ARGO_CDCL =
sig
  (* callback that justifies a literal by a clause, threading caller-specific state *)
  type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a

  (* solver state and queries *)
  type context
  val context: context
  val assignment_of: context -> Argo_Lit.literal -> bool option

  (* enriching the problem; add_axiom also reports how many decision levels were undone *)
  val add_atom: Argo_Term.term -> context -> context
  val add_axiom: Argo_Cls.clause -> context -> int * context

  (* search steps: external assumptions, unit propagation, decisions *)
  val assume: 'a explain -> Argo_Lit.literal -> context -> 'a ->
    Argo_Cls.clause option * context * 'a
  val propagate: context -> Argo_Common.literal Argo_Common.implied * context
  val decide: context -> context option

  (* conflict handling and restarts; the int is the number of levels undone *)
  val analyze: 'a explain -> Argo_Cls.clause -> context -> 'a -> int * context * 'a
  val restart: context -> int * context
end
structure Argo_Cdcl: ARGO_CDCL =
struct

(* basic types and operations *)

type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a

(*
  Justification of an assigned literal:
  - Level0 p: assigned at decision level 0; p proves the literal as a unit.
  - Decided (lvl, h, vals): the decision that opened level lvl; h is the trail length and
    vals the value table just before the decision, both restored when backjumping to
    level lvl - 1.
  - Implied (lvl, h, lrs, p): propagated at level lvl from the clause proved by p; h is the
    trail length before the assignment and lrs are the other, false literals of that clause
    together with their own justifications.
  - External lvl: assumed by the caller at a level lvl above 0; the justifying clause is only
    requested from the explain callback during conflict analysis.
*)
datatype reason =
  Level0 of Argo_Proof.proof |
  Decided of int * int * (bool * reason) Argo_Termtab.table |
  Implied of int * int * (Argo_Lit.literal * reason) list * Argo_Proof.proof |
  External of int

fun level_of (Level0 _) = 0
  | level_of (Decided (l, _, _)) = l
  | level_of (Implied (l, _, _, _)) = l
  | level_of (External l) = l

type justified = Argo_Lit.literal * reason

(*
  Watch lists per atom: the first component holds the clauses watching the negative literal,
  the second those watching the positive literal.  Hence watches_of, applied to a literal that
  has just become true, yields the clauses watching its (now false) negation.
*)
type watches = Argo_Cls.clause list * Argo_Cls.clause list

fun get_watches wts t = Argo_Termtab.lookup wts t
fun map_watches f t wts = Argo_Termtab.map_default (t, ([], [])) f wts

fun map_lit_watches f (Argo_Lit.Pos t) = map_watches (apsnd f) t
  | map_lit_watches f (Argo_Lit.Neg t) = map_watches (apfst f) t

fun watches_of wts (Argo_Lit.Pos t) = (case get_watches wts t of SOME (ws, _) => ws | NONE => [])
  | watches_of wts (Argo_Lit.Neg t) = (case get_watches wts t of SOME (_, ws) => ws | NONE => [])

fun attach cls lit = map_lit_watches (cons cls) lit
fun detach cls lit = map_lit_watches (remove Argo_Cls.eq_clause cls) lit

(* value lookups: raw_val_of ignores the polarity of the literal, val_of and value_of honour it *)
fun raw_val_of vals lit = Argo_Termtab.lookup vals (Argo_Lit.term_of lit)

fun val_of vals (Argo_Lit.Pos t) = Argo_Termtab.lookup vals t
  | val_of vals (Argo_Lit.Neg t) = Option.map (apfst not) (Argo_Termtab.lookup vals t)

fun value_of vals (Argo_Lit.Pos t) = Option.map fst (Argo_Termtab.lookup vals t)
  | value_of vals (Argo_Lit.Neg t) = Option.map (not o fst) (Argo_Termtab.lookup vals t)

fun justified vals lit = Option.map (pair lit o snd) (raw_val_of vals lit)
fun the_reason_of vals lit = snd (the (raw_val_of vals lit))

fun assign (Argo_Lit.Pos t) r = Argo_Termtab.update (t, (true, r))
  | assign (Argo_Lit.Neg t) r = Argo_Termtab.update (t, (false, r))


(* context *)

type trail = int * justified list

(*
  units: literals assigned but not yet propagated, newest first; level-0 ones carry a proof
  level: current decision level
  trail: number of assignments and all assignments, newest first
  vals: atom values together with their justifications
  wts: watch lists
  heap: literals from which decide picks the next decision
  clss: watched pairs of attached clauses with more than two literals
  prf: proof context
*)
type context = {
  units: Argo_Common.literal list,
  level: int,
  trail: int * justified list,
  vals: (bool * reason) Argo_Termtab.table,
  wts: watches Argo_Termtab.table,
  heap: Argo_Heap.heap,
  clss: Argo_Cls.table,
  prf: Argo_Proof.context}

fun mk_context units level trail vals wts heap clss prf: context =
  {units=units, level=level, trail=trail, vals=vals, wts=wts, heap=heap, clss=clss, prf=prf}

val context =
  mk_context [] 0 (0, []) Argo_Termtab.empty Argo_Termtab.empty Argo_Heap.heap
    Argo_Cls.table Argo_Proof.cdcl_context


(* backjumping *)

(*
  Pop the trail, re-inserting every popped literal (including decisions) into the heap, until
  the decision that opened level n + 1 has been removed; return the trail and value table as
  they were just before that decision.
*)
fun drop_levels n (Decided (l, h, vals)) trail heap =
      if l = n + 1 then ((h, trail), vals, heap) else drop_literal n trail heap
  | drop_levels n _ tr heap = drop_literal n tr heap

and drop_literal n ((lit, r) :: trail) heap = drop_levels n r trail (Argo_Heap.insert lit heap)
  | drop_literal _ [] _ = raise Fail "bad trail"

(*
  Backjump to new_level (clamped at 0) and return the number of levels undone together with
  the new context, whose pending units are dropped.  Nothing changes if the target is not
  below the current level.  The clamp is applied once, so the stored level and the returned
  count agree with the level actually restored.
*)
fun backjump_to new_level (cx as {level, trail=(_, tr), wts, heap, clss, prf, ...}: context) =
  let val new_level = Integer.max 0 new_level
  in
    if new_level >= level then (0, cx)
    else
      let val (trail, vals, heap) = drop_literal new_level tr heap
      in (level - new_level, mk_context [] new_level trail vals wts heap clss prf) end
  end


(* proofs *)

fun tag_clause (lits, p) prf = Argo_Proof.mk_clause lits p prf |>> pair lits

(* resolve proof p with the unit proofs of the given level-0 literals *)
fun level0_unit_proof (lit, Level0 p') (p, prf) = Argo_Proof.mk_unit_res lit p p' prf
  | level0_unit_proof _ _ = raise Fail "bad reason"

fun level0_unit_proofs lrs p prf = fold level0_unit_proof lrs (p, prf)

(*
  Final refutation from a clause whose literals are all false at level 0; Argo_Proof.unsat
  does not return normally (its result is used at several unrelated types below).
*)
fun unsat ({vals, prf, ...}: context) (lits, p) =
  let val lrs = map (fn lit => (lit, the_reason_of vals lit)) lits
  in Argo_Proof.unsat (fst (level0_unit_proofs lrs p prf)) end


(* assignments *)

(* assign lit, record it on the trail and queue it (with proof option p) for propagation *)
fun push lit p reason prf ({units, level, trail=(h, tr), vals, wts, heap, clss, ...}: context) =
  let val vals = assign lit reason vals
  in mk_context ((lit, p) :: units) level (h + 1, (lit, reason) :: tr) vals wts heap clss prf end

fun push_level0 lit p lrs (cx as {prf, ...}: context) =
  let val (p, prf) = level0_unit_proofs lrs p prf
  in push lit (SOME p) (Level0 p) prf cx end

(* above level 0 the justification is stored; at level 0 it is turned into a unit proof *)
fun push_implied lit p lrs (cx as {level, trail=(h, _), prf, ...}: context) =
  if level > 0 then push lit NONE (Implied (level, h, lrs, p)) prf cx
  else push_level0 lit p lrs cx

fun push_decided lit (cx as {level, trail=(h, _), vals, prf, ...}: context) =
  push lit NONE (Decided (level, h, vals)) prf cx

fun assignment_of ({vals, ...}: context) = value_of vals


(* clauses and watches *)

fun replace_watches old new cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  mk_context units level trail vals (attach cls new (detach cls old wts)) heap clss prf

fun as_clause cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let val (cls, prf) = tag_clause cls prf
  in (cls, mk_context units level trail vals wts heap clss prf) end

(* binary clauses need no stored watch pair: both literals are always watched *)
fun note_watches ([_, _], _) _ clss = clss
  | note_watches cls lp clss = Argo_Cls.put_watches cls lp clss

fun attach_clause lit1 lit2 (cls as (lits, _)) cx =
  let
    val {units, level, trail, vals, wts, heap, clss, prf}: context = cx
    val wts = attach cls lit1 (attach cls lit2 wts)
    val clss = note_watches cls (lit1, lit2) clss
  in mk_context units level trail vals wts (fold Argo_Heap.count lits heap) clss prf end

(* store the reordered watch pair only if order_lits_by actually swapped it *)
fun change_watches _ (false, _, _) cx = cx
  | change_watches cls (true, l1, l2) ({units, level, trail, vals, wts, heap, clss, prf}: context) =
      mk_context units level trail vals wts heap (Argo_Cls.put_watches cls (l1, l2) clss) prf

(* imply lit by cls (justified by the false literals lrs) and watch lit and lit' *)
fun add_asserting lit lit' (cls as (_, p)) lrs cx =
  attach_clause lit lit' cls (push_implied lit p lrs cx)

(*
  Learn a clause from conflict analysis: its first literal is asserting, lrs are the other
  (false) literals.  Backjump to the highest level among lrs and imply the first literal there.
*)
fun learn_clause _ ([lit], p) cx = backjump_to 0 cx ||> push_level0 lit p []
  | learn_clause lrs (cls as (lits, _)) cx =
      let
        fun max_level (l, r) (ll as (_, lvl)) = if level_of r > lvl then (l, level_of r) else ll
        val (lit, lvl) = fold max_level lrs (hd lits, 0)
      in backjump_to lvl cx ||> add_asserting (hd lits) lit cls lrs end


(* adding axioms *)

fun min lit i NONE = SOME (lit, i)
  | min lit i (SOME (lj as (_, j))) = SOME (if i < j then (lit, i) else lj)

fun level_ord ((_, r1), (_, r2)) = int_ord (level_of r2, level_of r1)

(*
  Insert a false literal into a list kept in descending order of levels.  Literals of equal
  level are all kept: every one of them is needed in the justification of an implied literal
  (and its level-0 unit proof), and backjump_add relies on seeing two literals of the same
  level.  (Ord_List.insert would discard an element comparing EQUAL to an existing one.)
*)
fun add_max lr [] = [lr]
  | add_max lr (lrs as lr' :: lrs') =
      if level_ord (lr, lr') = GREATER then lr' :: add_max lr lrs' else lr :: lrs

(*
  Split the literals of a clause by their values vs: the true literal of lowest level (if
  any), the unassigned literals, and the false literals with justifications, highest level
  first.
*)
fun part [] [] t us fs = (t, us, fs)
  | part (NONE :: vs) (l :: ls) t us fs = part vs ls t (l :: us) fs
  | part (SOME (true, r) :: vs) (l :: ls) t us fs = part vs ls (min l (level_of r) t) us fs
  | part (SOME (false, r) :: vs) (l :: ls) t us fs = part vs ls t us (add_max (l, r) fs)
  | part _ _ _ _ _ = raise Fail "mismatch between values and literals"

(*
  All literals of cls are false, lit at the highest level, lit' at the next highest; lrs are
  the false literals other than lit.  Undo the level of lit: if lit' is on the same level,
  both become unassigned and the clause is merely watched; otherwise lit is implied anew,
  justified by lrs, which all remain false.
*)
fun backjump_add (lit, r) (lit', r') cls lrs cx =
  let
    val add =
      if level_of r = level_of r' then attach_clause lit lit' cls
      else add_asserting lit lit' cls lrs
  in backjump_to (level_of r - 1) cx ||> add end

fun analyze_axiom vs (cls as (lits, p), cx) =
  (case part vs lits NONE [] [] of
    (SOME (lit, lvl), [], []) =>
      (* every literal is true: only a unit clause may be promoted to level 0 *)
      (case lits of
        [_] => if lvl > 0 then backjump_to 0 cx ||> push_implied lit p [] else (0, cx)
      | _ => (0, cx |> (lvl > 0) ? attach_clause lit (hd (remove Argo_Lit.eq_lit lit lits)) cls))
  | (SOME (lit, lvl), [], (lit', _) :: _) => (0, cx |> (lvl > 0) ? attach_clause lit lit' cls)
  | (SOME (lit, lvl), lit' :: _, _) => (0, cx |> (lvl > 0) ? attach_clause lit lit' cls)
  | (NONE, [], (_, Level0 _) :: _) => unsat cx cls
  | (NONE, [], [(lit, _)]) => backjump_to 0 cx ||> push_implied lit p []
  | (NONE, [], lr :: (lrs as lr' :: _)) => backjump_add lr lr' cls lrs cx
  | (NONE, [lit], []) => backjump_to 0 cx ||> push_implied lit p []
  | (NONE, [lit], lrs as (lit', _) :: _) => (0, add_asserting lit lit' cls lrs cx)
  | (NONE, lit1 :: lit2 :: _, _) => (0, attach_clause lit1 lit2 cls cx)
  | _ => raise Fail "bad clause")

fun add_atom t ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let val heap = Argo_Heap.insert (Argo_Lit.Pos t) heap
  in mk_context units level trail vals wts heap clss prf end

(* tautologies are ignored; duplicate literals are rejected *)
fun add_axiom ([], p) _ = Argo_Proof.unsat p
  | add_axiom (cls as (lits, _)) (cx as {vals, ...}: context) =
      if has_duplicates Argo_Lit.eq_lit lits then raise Fail "clause with duplicate literals"
      else if has_duplicates Argo_Lit.dual_lit lits then (0, cx)
      else analyze_axiom (map (val_of vals) lits) (as_clause cls cx)


(* external assignments *)

(*
  Assume lit from outside.  If lit is already false, the explanation of lit is returned as
  the conflict clause (or yields the refutation at level 0).  At level 0 the explanation is
  turned into a unit proof immediately; above level 0 it is requested lazily in analyze.
*)
fun assume explain lit (cx as {level, vals, prf, ...}: context) x =
  (case value_of vals lit of
    SOME true => (NONE, cx, x)
  | SOME false =>
      let val (cls, x) = explain lit x
      in if level = 0 then unsat cx cls else (SOME cls, cx, x) end
  | NONE =>
      if level = 0 then
        let val ((lits, p), x) = explain lit x
        in (NONE, push_level0 lit p (map_filter (justified vals) lits) cx, x) end
      else (NONE, push lit NONE (External level) prf cx, x))


(* propagation *)

exception CONFLICT of Argo_Cls.clause * context

(*
  Order a watch pair so that the literal over the same atom as lit (the one just falsified)
  comes last; the flag tells whether the pair was swapped.
*)
fun order_lits_by lit (l1, l2) =
  if Argo_Lit.eq_id (l1, lit) then (true, l2, l1) else (false, l1, l2)

fun prop_binary (_, implied_lit, other_lit) (cls as (_, p)) (cx as {level, vals, ...}: context) =
  (case value_of vals implied_lit of
    NONE => push_implied implied_lit p [(other_lit, the_reason_of vals other_lit)] cx
  | SOME true => cx
  | SOME false => if level = 0 then unsat cx cls else raise CONFLICT (cls, cx))

(* result of searching a clause for a replacement watch *)
datatype next = Lit of Argo_Lit.literal | None of justified list

fun with_non_false f l (SOME (false, r)) lrs = f ((l, r) :: lrs)
  | with_non_false _ l _ _ = Lit l

(* first literal other than lit that is not false, or all false ones with justifications *)
fun first_non_false _ _ [] lrs = None lrs
  | first_non_false vals lit (l :: ls) lrs =
      if Argo_Lit.eq_lit (l, lit) then first_non_false vals lit ls lrs
      else with_non_false (first_non_false vals lit ls) l (val_of vals l) lrs

(*
  lit2 (a watch of cls) has just become false; lit1 is the other watch.  Keep the watches if
  lit1 is true, move the watch from lit2 to another non-false literal if there is one, and
  otherwise imply lit1 or report a conflict.
*)
fun prop_nary (lp as (_, lit1, lit2)) (cls as (lits, p)) (cx as {level, vals, ...}: context) =
  let val v = value_of vals lit1
  in
    if v = SOME true then change_watches cls lp cx
    else
      (case first_non_false vals lit1 lits [] of
        Lit lit2' => change_watches cls (true, lit1, lit2') (replace_watches lit2 lit2' cls cx)
      | None lrs =>
          if v = NONE then push_implied lit1 p lrs (change_watches cls lp cx)
          else if level = 0 then unsat cx cls
          else raise CONFLICT (cls, change_watches cls lp cx))
  end

fun prop_cls lit (cls as ([l1, l2], _)) cx = prop_binary (order_lits_by lit (l1, l2)) cls cx
  | prop_cls lit cls (cx as {clss, ...}: context) =
      prop_nary (order_lits_by lit (Argo_Cls.get_watches clss cls)) cls cx

fun prop_lit (lp as (lit, _)) (lps, cx as {wts, ...}: context) =
  (lp :: lps, fold (prop_cls lit) (watches_of wts lit) cx)

(* propagate pending units, oldest first, until none are left; lps collects them *)
fun prop lps (cx as {units=[], ...}: context) = (Argo_Common.Implied (rev lps), cx)
  | prop lps ({units, level, trail, vals, wts, heap, clss, prf}: context) =
      fold_rev prop_lit units (lps, mk_context [] level trail vals wts heap clss prf) |-> prop

fun propagate cx = prop [] cx
  handle CONFLICT (cls, cx) => (Argo_Common.Conflict cls, cx)


(* decisions *)

(*
  Take literals from the heap until one over an unassigned atom is found and assign it on a
  new decision level; NONE once the heap is exhausted.  Skipped literals leave the heap; they
  are re-inserted when backjumping pops them from the trail.
*)
fun decide ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let
    fun check NONE = NONE
      | check (SOME (lit, heap)) =
          if Argo_Termtab.defined vals (Argo_Lit.term_of lit) then check (Argo_Heap.extract heap)
          else SOME (push_decided lit (mk_context units (level + 1) trail vals wts heap clss prf))
  in check (Argo_Heap.extract heap) end


(* conflict analysis *)

exception ESSENTIAL of unit

(*
  Order proof steps by decreasing trail position; level-0 steps (position ~1) come last,
  among themselves ordered by signed literal id.
*)
fun history_ord ((h1, lit1, _), (h2, lit2, _)) =
  if h1 < 0 andalso h2 < 0 then int_ord (apply2 Argo_Lit.signed_id_of (lit1, lit2))
  else int_ord (h2, h1)

fun rec_redundant stop (lit, Implied (lvl, h, lrs, p)) lps =
      if stop lit lvl then lps
      else fold (rec_redundant stop) lrs ((h, lit, p) :: lps)
  | rec_redundant stop (lit, Decided (lvl, _, _)) lps =
      if stop lit lvl then lps
      else raise ESSENTIAL ()
  | rec_redundant _ (lit, Level0 p) lps = ((~1, lit, p) :: lps)
  | rec_redundant _ _ _ = raise ESSENTIAL ()

fun redundant stop (lr as (lit, Implied (_, h, lrs, p))) (lps, essential_lrs) = (
      (fold (rec_redundant stop) lrs ((h, lit, p) :: lps), essential_lrs)
      handle ESSENTIAL () => (lps, lr :: essential_lrs))
  | redundant _ lr (lps, essential_lrs) = (lps, lr :: essential_lrs)

fun resolve_step (_, l, p') (p, prf) = Argo_Proof.mk_unit_res l p p' prf

(*
  Clause minimization: an implied literal of lrs is dropped when following justifications
  recursively only reaches literals of the clause, level-0 literals, or implied literals on
  levels occurring in the clause; the proof steps of dropped literals are resolved into p.
  Returns the remaining literals and the updated proof.
*)
fun reduce lrs p prf =
  let
    val lits = map fst lrs
    val levels = fold (insert (op =) o level_of o snd) lrs []
    fun stop lit level =
      if member Argo_Lit.eq_lit lits lit then true
      else if member (op =) levels level then false
      else raise ESSENTIAL ()
    val (lps, lrs) = fold (redundant stop) lrs ([], [])
  in (lrs, fold resolve_step (sort_distinct history_ord lps) (p, prf)) end

fun unmark lit ms = remove Argo_Lit.eq_id lit ms
fun marked ms lit = member Argo_Lit.eq_id ms lit

fun justification_for _ _ _ (Implied (_, _, lrs, p)) x = (lrs, p, x)
  | justification_for explain vals lit (External _) x =
      let val ((lits, p), x) = explain lit x
      in (map_filter (justified vals) lits, p, x) end
  | justification_for _ _ _ _ _ = raise Fail "bad reason"

fun first_lit pred ((lr as (lit, _)) :: lrs) = if pred lit then (lr, lrs) else first_lit pred lrs
  | first_lit _ _ = raise Empty

(*
  First-UIP analysis of the falsified clause cls (at level 0 it yields the refutation):
  - from_clause and from_reason walk the literals of a clause: level-0 literals are resolved
    away, literals of the current level are marked, and literals of lower levels are collected
    once per atom for the learned clause; Argo_Heap.increase is applied to each newly marked
    or collected literal;
  - from_trail resolves the most recent marked trail literal with its justification until
    only one marked literal remains.
  The learned clause consists of the negation of that literal and the minimized collected
  literals; the heap is decayed and learn_clause backjumps and asserts the clause.
*)
fun analyze explain cls (cx as {level, trail, vals, wts, heap, clss, prf, ...}: context) x =
  let
    fun from_clause [] trail ms lrs h p prf x =
          from_trail (first_lit (marked ms) trail) ms lrs h p prf x
      | from_clause ((lit, r) :: clause_lrs) trail ms lrs h p prf x =
          from_reason r lit clause_lrs trail ms lrs h p prf x

    and from_reason (Level0 p') lit clause_lrs trail ms lrs h p prf x =
          let val (p, prf) = Argo_Proof.mk_unit_res lit p p' prf
          in from_clause clause_lrs trail ms lrs h p prf x end
      | from_reason r lit clause_lrs trail ms lrs h p prf x =
          if level_of r = level then
            if marked ms lit then from_clause clause_lrs trail ms lrs h p prf x
            else from_clause clause_lrs trail (lit :: ms) lrs (Argo_Heap.increase lit h) p prf x
          else
            let
              val (lrs, h) =
                if AList.defined Argo_Lit.eq_id lrs lit then (lrs, h)
                else ((lit, r) :: lrs, Argo_Heap.increase lit h)
            in from_clause clause_lrs trail ms lrs h p prf x end

    and from_trail ((lit, _), _) [_] lrs h p prf x =
          let val (lrs, (p, prf)) = reduce lrs p prf
          in (Argo_Lit.negate lit :: map fst lrs, lrs, h, p, prf, x) end
      | from_trail ((lit, r), trail) ms lrs h p prf x =
          let
            val (clause_lrs, p', x) = justification_for explain vals lit r x
            val (p, prf) = Argo_Proof.mk_unit_res lit p' p prf
          in from_clause clause_lrs trail (unmark lit ms) lrs h p prf x end

    val (ls, p) = cls
    val lrs = if level = 0 then unsat cx cls else map (fn l => (l, the_reason_of vals l)) ls
    val (lits, lrs, heap, p, prf, x) = from_clause lrs (snd trail) [] [] heap p prf x
    val heap = Argo_Heap.decay heap
    val (levels, cx) = learn_clause lrs (lits, p) (mk_context [] level trail vals wts heap clss prf)
  in (levels, cx, x) end


(* restarts *)

fun restart cx = backjump_to 0 cx

end