(* Theory Binomial_Heap *)

(* Author: Peter Lammich
           Tobias Nipkow (tuning)
*)

section ‹Binomial Priority Queue›

theory Binomial_Heap
imports
  "HOL-Library.Pattern_Aliases"
  Complex_Main
  Priority_Queue_Specs
  Time_Funs
begin

text ‹
  We formalize the presentation from Okasaki's book.
  We show the functional correctness and complexity of all operations.

  The presentation is engineered for simplicity, and most
  proofs are straightforward and automatic.
›

subsection ‹Binomial Tree and Forest Types›

(* A tree node stores its rank, the element at its root, and the list of its
   child trees; rank, root and children are the selector names given in the
   datatype declaration. *)
datatype 'a tree = Node (rank: nat) (root: 'a) (children: "'a tree list")

(* A forest is a list of trees. *)
type_synonym 'a forest = "'a tree list"

subsubsection ‹Multiset of elements›

(* Multiset of the elements of a tree: the root element together with the
   multiset sum of the elements of all child trees. *)
fun mset_tree :: "'a::linorder tree ⇒ 'a multiset" where
  "mset_tree (Node _ a ts) = {#a#} + (∑t∈#mset ts. mset_tree t)"

(* Multiset of the elements of a forest: the multiset sum over its trees. *)
definition mset_forest :: "'a::linorder forest ⇒ 'a multiset" where
  "mset_forest ts = (∑t∈#mset ts. mset_tree t)"

(* Simp rule for mset_tree phrased in terms of mset_forest.  The defining
   equations of mset_tree are removed from the simpset right afterwards, so
   this rule takes their place. *)
lemma mset_tree_simp_alt[simp]:
  "mset_tree (Node r a ts) = {#a#} + mset_forest ts"
  unfolding mset_forest_def by auto
declare mset_tree.simps[simp del]

(* A tree is never empty: it contains at least its root element. *)
lemma mset_tree_nonempty[simp]: "mset_tree t ≠ {#}"
by (cases t) auto

(* The empty forest has no elements. *)
lemma mset_forest_Nil[simp]:
  "mset_forest [] = {#}"
by (auto simp: mset_forest_def)

(* The elements of a non-empty forest are those of its first tree plus those
   of the remaining forest. *)
lemma mset_forest_Cons[simp]: "mset_forest (t#ts) = mset_tree t + mset_forest ts"
by (auto simp: mset_forest_def)

(* A forest has no elements exactly when it is the empty list. *)
lemma mset_forest_empty_iff[simp]: "mset_forest ts = {#} ⟷ ts=[]"
by (auto simp: mset_forest_def)

(* The root element of a tree is one of its elements. *)
lemma root_in_mset[simp]: "root t ∈# mset_tree t"
by (cases t) auto

(* Reversing a forest does not change its multiset of elements. *)
lemma mset_forest_rev_eq[simp]: "mset_forest (rev ts) = mset_forest ts"
by (auto simp: mset_forest_def)

subsubsection ‹Invariants›

text ‹Binomial tree›
(* A node of rank r is a binomial tree if all its children are binomial trees
   and the ranks of the children are r-1, ..., 0 in this order. *)
fun btree :: "'a::linorder tree ⇒ bool" where
"btree (Node r x ts) ⟷
   (∀t∈set ts. btree t) ∧ map rank ts = rev [0..<r]"

text ‹Heap invariant›
(* The root element is less than or equal to the root of every child, and
   every child satisfies the heap invariant itself. *)
fun heap :: "'a::linorder tree ⇒ bool" where
"heap (Node _ x ts) ⟷ (∀t∈set ts. heap t ∧ x ≤ root t)"

(* A tree is a bheap if it satisfies both the structural invariant btree and
   the ordering invariant heap. *)
definition "bheap t ⟷ btree t ∧ heap t"

text ‹Binomial Forest invariant:›
(* Every tree of the forest is a bheap, and the ranks are strictly increasing
   along the list. *)
definition "invar ts ⟷ (∀t∈set ts. bheap t) ∧ (sorted_wrt (<) (map rank ts))"

text ‹A binomial forest is often called a binomial heap, but this overloads the latter term.›

text ‹The children of a binomial heap node are a valid forest:›
(* btree stores the children with decreasing rank, whereas invar demands
   increasing rank; hence the list is reversed. *)
lemma invar_children:
  "bheap (Node r v ts) ⟹ invar (rev ts)"
  by (auto simp: bheap_def invar_def rev_map[symmetric])


subsection ‹Operations and Their Functional Correctness›

subsubsection ‹‹link››

context
includes pattern_aliases
begin

(* Combine two trees into one of rank r+1: the tree with the smaller root
   (the first one on ties) becomes the new root and receives the other tree
   as its new first child.  The rank r' of the second argument is ignored;
   lemma invar_link assumes that both ranks are equal.  The alias pattern =:
   names a whole argument while also matching its components. *)
fun link :: "('a::linorder) tree ⇒ 'a tree ⇒ 'a tree" where
  "link (Node r x1 ts1 =: t1) (Node r' x2 ts2 =: t2) =
    (if x1≤x2 then Node (r+1) x1 (t2#ts1) else Node (r+1) x2 (t1#ts2))"

end

(* Linking two bheap trees of equal rank yields a bheap tree again. *)
lemma invar_link:
  assumes "bheap t1"
  assumes "bheap t2"
  assumes "rank t1 = rank t2"
  shows "bheap (link t1 t2)"
using assms unfolding bheap_def
by (cases "(t1, t2)" rule: link.cases) auto

(* The linked tree has the rank of the first argument plus one. *)
lemma rank_link[simp]: "rank (link t1 t2) = rank t1 + 1"
by (cases "(t1, t2)" rule: link.cases) simp

(* Linking preserves elements: the result contains exactly the elements of
   both argument trees. *)
lemma mset_link[simp]: "mset_tree (link t1 t2) = mset_tree t1 + mset_tree t2"
by (cases "(t1, t2)" rule: link.cases) simp

subsubsection ‹‹ins_tree››

(* Insert a tree into a forest: if its rank is below that of the first tree
   it is put in front; otherwise it is linked with the first tree and the
   result is inserted into the rest of the forest. *)
fun ins_tree :: "'a::linorder tree ⇒ 'a forest ⇒ 'a forest" where
  "ins_tree t [] = [t]"
| "ins_tree t1 (t2#ts) =
  (if rank t1 < rank t2 then t1#t2#ts else ins_tree (link t1 t2) ts)"

(* A node of rank 0 without children is a bheap tree. *)
lemma bheap0[simp]: "bheap (Node 0 x [])"
unfolding bheap_def by auto

(* The forest invariant unfolded for a non-empty forest: the head is a bheap
   tree, the tail is a valid forest, and the rank of the head is strictly
   below every rank in the tail. *)
lemma invar_Cons[simp]:
  "invar (t#ts)
  ⟷ bheap t ∧ invar ts ∧ (∀t'∈set ts. rank t < rank t')"
by (auto simp: invar_def)

(* ins_tree preserves the forest invariant, provided the inserted tree is a
   bheap tree whose rank does not exceed any rank in the forest. *)
lemma invar_ins_tree:
  assumes "bheap t"
  assumes "invar ts"
  assumes "∀t'∈set ts. rank t ≤ rank t'"
  shows "invar (ins_tree t ts)"
using assms
by (induction t ts rule: ins_tree.induct) (auto simp: invar_link less_eq_Suc_le[symmetric])

(* ins_tree adds exactly the elements of the inserted tree. *)
lemma mset_forest_ins_tree[simp]:
  "mset_forest (ins_tree t ts) = mset_tree t + mset_forest ts"
by (induction t ts rule: ins_tree.induct) auto

(* A strict lower bound on the rank of t and on all ranks in ts is also a
   strict lower bound on the ranks of all trees in ins_tree t ts. *)
lemma ins_tree_rank_bound:
  assumes "t' ∈ set (ins_tree t ts)"
  assumes "∀t'∈set ts. rank t0 < rank t'"
  assumes "rank t0 < rank t"
  shows "rank t0 < rank t'"
using assms
by (induction t ts rule: ins_tree.induct) (auto split: if_splits)

subsubsection ‹‹insert››

(* Hide the existing constant insert so that the unqualified name can be used
   for the forest operation defined next; this theory later refers to the set
   operation by its qualified name Set.insert. *)
hide_const (open) insert

(* Insert an element by inserting a childless tree of rank 0 holding it. *)
definition insert :: "'a::linorder ⇒ 'a forest ⇒ 'a forest" where
"insert x ts = ins_tree (Node 0 x []) ts"

(* Insertion of an element preserves the forest invariant. *)
lemma invar_insert[simp]: "invar t ⟹ invar (insert x t)"
by (auto intro!: invar_ins_tree simp: insert_def)

(* Insertion adds exactly the inserted element. *)
lemma mset_forest_insert[simp]: "mset_forest (insert x t) = {#x#} + mset_forest t"
by(auto simp: insert_def)

subsubsection ‹‹merge››

context
includes pattern_aliases
begin

(* Merge two forests by walking along both lists: of two leading trees with
   different ranks the one of smaller rank is emitted first; two leading trees
   of equal rank are linked and the result is inserted, via ins_tree, into the
   merge of the two remaining forests.  The aliases f1 and f2 name the whole
   non-empty argument forests. *)
fun merge :: "'a::linorder forest ⇒ 'a forest ⇒ 'a forest" where
  "merge ts1 [] = ts1"
| "merge [] ts2 = ts2"
| "merge (t1#ts1 =: f1) (t2#ts2 =: f2) = (
    if rank t1 < rank t2 then t1 # merge ts1 f2 else
    if rank t2 < rank t1 then t2 # merge f1 ts2
    else ins_tree (link t1 t2) (merge ts1 ts2)
  )"

end

(* Merging the empty forest with ts2 gives ts2, stated as a simp rule for
   arbitrary ts2 and proved by case analysis on ts2. *)
lemma merge_simp2[simp]: "merge [] ts2 = ts2"
by (cases ts2) auto

(* A strict lower bound on the ranks of all trees in both argument forests is
   also a strict lower bound on the ranks of all trees in their merge. *)
lemma merge_rank_bound:
  assumes "t' ∈ set (merge ts1 ts2)"
  assumes "∀t12∈set ts1 ∪ set ts2. rank t < rank t12"
  shows "rank t < rank t'"
using assms
by (induction ts1 ts2 arbitrary: t' rule: merge.induct)
   (auto split: if_splits simp: ins_tree_rank_bound)

(* merge preserves the forest invariant.  The proof is by the induction rule
   of merge; merge_rank_bound is supplied as an elimination rule. *)
lemma invar_merge[simp]:
  assumes "invar ts1"
  assumes "invar ts2"
  shows "invar (merge ts1 ts2)"
using assms
by (induction ts1 ts2 rule: merge.induct)
   (auto 0 3 simp: Suc_le_eq intro!: invar_ins_tree invar_link elim!: merge_rank_bound)


text ‹Longer, more explicit proof of @{thm [source] invar_merge},
      to illustrate the application of the @{thm [source] merge_rank_bound} lemma.›
(* Unnamed lemma: it restates invar_merge and exists only for illustration. *)
lemma
  assumes "invar ts1"
  assumes "invar ts2"
  shows "invar (merge ts1 ts2)"
  using assms
proof (induction ts1 ts2 rule: merge.induct)
  case (3 t1 ts1 t2 ts2)
  ― ‹Invariants of the parts can be shown automatically›
  from "3.prems" have [simp]:
    "bheap t1" "bheap t2"
    (*"invar (merge (t1#ts1) ts2)" 
    "invar (merge ts1 (t2#ts2))"
    "invar (merge ts1 ts2)"*)
    by auto

  ― ‹These are the three cases of the @{const merge} function›
  consider (LT) "rank t1 < rank t2"
         | (GT) "rank t1 > rank t2"
         | (EQ) "rank t1 = rank t2"
    using antisym_conv3 by blast
  then show ?case proof cases
    case LT
    ― ‹@{const merge} takes the first tree from the left heap›
    then have "merge (t1 # ts1) (t2 # ts2) = t1 # merge ts1 (t2 # ts2)" by simp
    also have "invar …" proof (simp, intro conjI)
      ― ‹Invariant follows from induction hypothesis›
      show "invar (merge ts1 (t2 # ts2))"
        using LT "3.IH" "3.prems" by simp

      ― ‹It remains to show that ‹t1› has smallest rank.›
      show "∀t'∈set (merge ts1 (t2 # ts2)). rank t1 < rank t'"
        ― ‹Which is done by auxiliary lemma @{thm [source] merge_rank_bound}›
        using LT "3.prems" by (force elim!: merge_rank_bound)
    qed
    finally show ?thesis .
  next
    ― ‹@{const merge} takes the first tree from the right heap›
    case GT
    ― ‹The proof is analogous to the ‹LT› case›
    then show ?thesis using "3.prems" "3.IH" by (force elim!: merge_rank_bound)
  next
    case [simp]: EQ
    ― ‹@{const merge} links both first trees, and inserts the result into the merged remaining heaps›
    have "merge (t1 # ts1) (t2 # ts2) = ins_tree (link t1 t2) (merge ts1 ts2)" by simp
    also have "invar …" proof (intro invar_ins_tree invar_link)
      ― ‹Invariant of merged remaining heaps follows by IH›
      show "invar (merge ts1 ts2)"
        using EQ "3.prems" "3.IH" by auto

      ― ‹For insertion, we have to show that the rank of the linked tree is ‹≤› the
          ranks in the merged remaining heaps›
      show "∀t'∈set (merge ts1 ts2). rank (link t1 t2) ≤ rank t'"
      proof -
        ― ‹Which is, again, done with the help of @{thm [source] merge_rank_bound}›
        have "rank (link t1 t2) = Suc (rank t2)" by simp
        thus ?thesis using "3.prems" by (auto simp: Suc_le_eq elim!: merge_rank_bound)
      qed
    qed simp_all
    finally show ?thesis .
  qed
qed auto


(* Merging preserves elements: the result contains exactly the elements of
   both argument forests. *)
lemma mset_forest_merge[simp]:
  "mset_forest (merge ts1 ts2) = mset_forest ts1 + mset_forest ts2"
by (induction ts1 ts2 rule: merge.induct) auto

subsubsection ‹‹get_min››

(* Least root among the trees of a non-empty forest.  No equation is given
   for the empty forest. *)
fun get_min :: "'a::linorder forest ⇒ 'a" where
  "get_min [t] = root t"
| "get_min (t#ts) = min (root t) (get_min ts)"

(* The root of a bheap tree is a lower bound of all elements of the tree. *)
lemma bheap_root_min:
  assumes "bheap t"
  assumes "x ∈# mset_tree t"
  shows "root t ≤ x"
using assms unfolding bheap_def
by (induction t arbitrary: x rule: mset_tree.induct) (fastforce simp: mset_forest_def)

(* For a non-empty valid forest, get_min is a lower bound of all elements. *)
lemma get_min_mset:
  assumes "ts≠[]"
  assumes "invar ts"
  assumes "x ∈# mset_forest ts"
  shows "get_min ts ≤ x"
  using assms
apply (induction ts arbitrary: x rule: get_min.induct)
apply (auto
      simp: bheap_root_min min_def intro: order_trans;
      meson linear order_trans bheap_root_min
      )+
done

(* For a non-empty forest, get_min returns one of its elements. *)
lemma get_min_member:
  "ts≠[] ⟹ get_min ts ∈# mset_forest ts"
by (induction ts rule: get_min.induct) (auto simp: min_def)

(* Functional correctness of get_min: on a valid forest with at least one
   element it returns the minimum of the element multiset.  Combines
   get_min_member (it is an element) and get_min_mset (it is a lower bound). *)
lemma get_min:
  assumes "mset_forest ts ≠ {#}"
  assumes "invar ts"
  shows "get_min ts = Min_mset (mset_forest ts)"
using assms get_min_member get_min_mset
by (auto simp: eq_Min_iff)

subsubsection ‹‹get_min_rest››

(* Split a non-empty forest into a tree with least root and the remaining
   trees, which keep their relative order.  On equal roots the earlier tree
   is chosen.  No equation is given for the empty forest. *)
fun get_min_rest :: "'a::linorder forest ⇒ 'a tree × 'a forest" where
  "get_min_rest [t] = (t,[])"
| "get_min_rest (t#ts) = (let (t',ts') = get_min_rest ts
                     in if root t ≤ root t' then (t,ts) else (t',t#ts'))"

(* The tree selected by get_min_rest has the root computed by get_min. *)
lemma get_min_rest_get_min_same_root:
  assumes "ts≠[]"
  assumes "get_min_rest ts = (t',ts')"
  shows "root t' = get_min ts"
using assms
by (induction ts arbitrary: t' ts' rule: get_min.induct) (auto simp: min_def split: prod.splits)

(* get_min_rest only rearranges trees: as a multiset of trees, the input is
   the selected tree plus the remaining forest. *)
lemma mset_get_min_rest:
  assumes "get_min_rest ts = (t',ts')"
  assumes "ts≠[]"
  shows "mset ts = {#t'#} + mset ts'"
using assms
by (induction ts arbitrary: t' ts' rule: get_min.induct) (auto split: prod.splits if_splits)

(* Set version of mset_get_min_rest, obtained by applying set_mset to both
   sides of that equation. *)
lemma set_get_min_rest:
  assumes "get_min_rest ts = (t', ts')"
  assumes "ts≠[]"
  shows "set ts = Set.insert t' (set ts')"
using mset_get_min_rest[OF assms, THEN arg_cong[where f=set_mset]]
by auto

(* On a non-empty valid forest, get_min_rest returns a bheap tree and a valid
   remaining forest.  Both conclusions are proved together as a conjunction
   by induction and then split. *)
lemma invar_get_min_rest:
  assumes "get_min_rest ts = (t',ts')"
  assumes "ts≠[]"
  assumes "invar ts"
  shows "bheap t'" and "invar ts'"
proof -
  have "bheap t' ∧ invar ts'"
    using assms
    proof (induction ts arbitrary: t' ts' rule: get_min.induct)
      case (2 t v va)
      then show ?case
        apply (clarsimp split: prod.splits if_splits)
        apply (drule set_get_min_rest; fastforce)
        done
    qed auto
  thus "bheap t'" and "invar ts'" by auto
qed

subsubsection ‹‹del_min››

(* Delete a minimal element: take out the tree selected by get_min_rest, drop
   its root, and merge its children with the remaining forest.  The children
   are passed through itrev first; the proofs below rewrite this with
   itrev_Nil and treat it as the reversed child list (cf. invar_children). *)
definition del_min :: "'a::linorder forest ⇒ 'a::linorder forest" where
"del_min ts = (case get_min_rest ts of
   (Node r x ts1, ts2) ⇒ merge (itrev ts1 []) ts2)"

(* del_min preserves the forest invariant on non-empty forests. *)
lemma invar_del_min[simp]:
  assumes "ts ≠ []"
  assumes "invar ts"
  shows "invar (del_min ts)"
using assms
unfolding del_min_def itrev_Nil
by (auto
      split: prod.split tree.split
      intro!: invar_merge invar_children
      dest: invar_get_min_rest
    )

(* del_min removes exactly one occurrence of get_min ts from the elements of
   a non-empty forest. *)
lemma mset_forest_del_min:
  assumes "ts ≠ []"
  shows "mset_forest ts = mset_forest (del_min ts) + {# get_min ts #}"
using assms
unfolding del_min_def itrev_Nil
apply (clarsimp split: tree.split prod.split)
apply (frule (1) get_min_rest_get_min_same_root)
apply (frule (1) mset_get_min_rest)
apply (auto simp: mset_forest_def)
done


subsubsection ‹Instantiating the Priority Queue Locale›

text ‹Last step of functional correctness proof: combine all the above lemmas
to show that binomial heaps satisfy the specification of priority queues with merge.›

(* Instantiate the locale Priority_Queue_Merge with the operations of this
   theory; the empty queue is the empty list.  unfold_locales produces ten
   goals.  Most are closed by the simp rules proved above; goals 4 and 5 use
   the del_min and get_min correctness lemmas explicitly. *)
interpretation bheaps: Priority_Queue_Merge
  where empty = "[]" and is_empty = "(=) []" and insert = insert
  and get_min = get_min and del_min = del_min and merge = merge
  and invar = invar and mset = mset_forest
proof (unfold_locales, goal_cases)
  case 1 thus ?case by simp
next
  case 2 thus ?case by auto
next
  case 3 thus ?case by auto
next
  case (4 q)
  thus ?case using mset_forest_del_min[of q] get_min[OF _ ‹invar q›]
    by (auto simp: union_single_eq_diff)
next
  case (5 q) thus ?case using get_min[of q] by auto
next
  case 6 thus ?case by (auto simp add: invar_def)
next
  case 7 thus ?case by simp
next
  case 8 thus ?case by simp
next
  case 9 thus ?case by simp
next
  case 10 thus ?case by simp
qed


subsection ‹Complexity›

text ‹The size of a binomial tree is determined by its rank›
(* A binomial tree of rank r has exactly 2^r elements.  The calculation
   counts the elements of the children: their ranks are r-1, ..., 0 (fact
   COMPL), so by the induction hypothesis they hold 2^0 + ... + 2^(r-1)
   = 2^r - 1 elements; the root contributes the remaining one. *)
lemma size_mset_btree:
  assumes "btree t"
  shows "size (mset_tree t) = 2^rank t"
  using assms
proof (induction t)
  case (Node r v ts)
  hence IH: "size (mset_tree t) = 2^rank t" if "t ∈ set ts" for t
    using that by auto

  from Node have COMPL: "map rank ts = rev [0..<r]" by auto

  have "size (mset_forest ts) = (∑t←ts. size (mset_tree t))"
    by (induction ts) auto
  also have "… = (∑t←ts. 2^rank t)" using IH
    by (auto cong: map_cong)
  also have "… = (∑r←map rank ts. 2^r)"
    by (induction ts) auto
  also have "… = (∑i∈{0..<r}. 2^i)"
    unfolding COMPL
    by (auto simp: rev_map[symmetric] interv_sum_list_conv_sum_set_nat)
  also have "… = 2^r - 1"
    by (induction r) auto
  finally show ?case
    by (simp)
qed

(* Same size statement as size_mset_btree, under the stronger assumption
   bheap, which is the form in which trees occur in the forest invariant. *)
lemma size_mset_tree:
  assumes "bheap t"
  shows "size (mset_tree t) = 2^rank t"
using assms unfolding bheap_def
by (simp add: size_mset_btree)

text ‹The length of a binomial heap is bounded by the number of its elements›
(* The number of trees of a valid forest is at most log 2 (n + 1), where n is
   the number of its elements.  The calculation shows
   2^length ts ≤ n + 1: the sum 2^0 + ... + 2^(length ts - 1) is bounded by
   the sum of 2^rank t over the trees (the ranks are strictly increasing,
   fact ASC), and each tree holds 2^rank t elements (size_mset_tree). *)
lemma size_mset_forest:
  assumes "invar ts"
  shows "length ts ≤ log 2 (size (mset_forest ts) + 1)"
proof -
  from ‹invar ts› have
    ASC: "sorted_wrt (<) (map rank ts)" and
    TINV: "∀t∈set ts. bheap t"
    unfolding invar_def by auto

  have "(2::nat)^length ts = (∑i∈{0..<length ts}. 2^i) + 1"
    by (simp add: sum_power2)
  also have "… = (∑i←[0..<length ts]. 2^i) + 1" (is "_ = ?S + 1")
    by (simp add: interv_sum_list_conv_sum_set_nat)
  also have "?S ≤ (∑t←ts. 2^rank t)" (is "_ ≤ ?T")
    using sorted_wrt_less_idx[OF ASC] by(simp add: sum_list_mono2)
  also have "?T + 1 ≤ (∑t←ts. size (mset_tree t)) + 1" using TINV
    by (auto cong: map_cong simp: size_mset_tree)
  also have "… = size (mset_forest ts) + 1"
    unfolding mset_forest_def by (induction ts) auto
  finally have "2^length ts ≤ size (mset_forest ts) + 1" by simp
  then show ?thesis using le_log2_of_power by blast
qed

subsubsection ‹Timing Functions›

(* Generate the timing function T_link for link. *)
time_fun link

(* Under the generated definition, link costs no time steps. *)
lemma T_link[simp]: "T_link t1 t2 = 0"
by(cases t1; cases t2, auto)

(* Generate the timing function T_rank for the selector rank. *)
time_fun rank

(* Under the generated definition, rank costs no time steps. *)
lemma T_rank[simp]: "T_rank t = 0"
by(cases t, auto)

(* Generate the timing functions T_ins_tree and T_insert. *)
time_fun ins_tree

time_fun insert

(* The running time of ins_tree is at most linear in the length of the forest. *)
lemma T_ins_tree_simple_bound: "T_ins_tree t ts ≤ length ts + 1"
by (induction t ts rule: T_ins_tree.induct) auto

subsubsection ‹‹T_insert››

(* Insertion into a valid forest takes logarithmic time in the number of
   elements: the linear bound in the forest length (T_ins_tree_simple_bound)
   is combined with the length bound size_mset_forest. *)
lemma T_insert_bound:
  assumes "invar ts"
  shows "T_insert x ts ≤ log 2 (size (mset_forest ts) + 1) + 1"
proof -
  have "real (T_insert x ts) ≤ real (length ts) + 1"
    unfolding T_insert.simps using T_ins_tree_simple_bound
    by (metis of_nat_1 of_nat_add of_nat_mono)
  also note size_mset_forest[OF ‹invar ts›]
  finally show ?thesis by simp
qed

subsubsection ‹‹T_merge››

(* Generate the timing function T_merge for merge. *)
time_fun merge

(* Warning: ‹T_merge.induct› is less convenient than the equivalent ‹merge.induct›,
apparently because of the ‹let› clauses introduced by pattern_aliases; should be investigated.
*)

text ‹A crucial idea is to estimate the time in correlation with the
  result length, as each carry reduces the length of the result.›

(* Exact relation between the running time of ins_tree and the length of its
   result: their sum is the length of the input forest plus two. *)
lemma T_ins_tree_length:
  "T_ins_tree t ts + length (ins_tree t ts) = 2 + length ts"
by (induction t ts rule: ins_tree.induct) auto

(* Running time of merge plus the length of its result is bounded linearly
   in the lengths of the two argument forests. *)
lemma T_merge_length:
  "T_merge ts1 ts2 + length (merge ts1 ts2) ≤ 2 * (length ts1 + length ts2) + 1"
by (induction ts1 ts2 rule: merge.induct)
   (auto simp: T_ins_tree_length algebra_simps)

text ‹Finally, we get the desired logarithmic bound›
(* Merging two valid forests with n1 and n2 elements takes time logarithmic
   in n1 + n2.  The linear bound T_merge_length is turned into a logarithmic
   one with size_mset_forest, and both logarithms are then bounded by
   log 2 (n1 + n2 + 1). *)
lemma T_merge_bound:
  fixes ts1 ts2
  defines "n1 ≡ size (mset_forest ts1)"
  defines "n2 ≡ size (mset_forest ts2)"
  assumes "invar ts1" "invar ts2"
  shows "T_merge ts1 ts2 ≤ 4*log 2 (n1 + n2 + 1) + 1"
proof -
  note n_defs = assms(1,2)

  have "T_merge ts1 ts2 ≤ 2 * real (length ts1) + 2 * real (length ts2) + 1"
    using T_merge_length[of ts1 ts2] by simp
  also note size_mset_forest[OF ‹invar ts1›]
  also note size_mset_forest[OF ‹invar ts2›]
  finally have "T_merge ts1 ts2 ≤ 2 * log 2 (n1 + 1) + 2 * log 2 (n2 + 1) + 1"
    unfolding n_defs by (simp add: algebra_simps)
  also have "log 2 (n1 + 1) ≤ log 2 (n1 + n2 + 1)"
    unfolding n_defs by (simp add: algebra_simps)
  also have "log 2 (n2 + 1) ≤ log 2 (n1 + n2 + 1)"
    unfolding n_defs by (simp add: algebra_simps)
  finally show ?thesis by (simp add: algebra_simps)
qed

subsubsection ‹‹T_get_min››

(* Generate the timing function T_root for the selector root. *)
time_fun root

(* Under the generated definition, root costs no time steps. *)
lemma T_root[simp]: "T_root t = 0"
by(cases t)(simp_all)

(* Generate the timing functions T_min and T_get_min. *)
time_fun min

time_fun get_min

(* On a non-empty forest the running time of get_min equals its length. *)
lemma T_get_min_estimate: "ts≠[] ⟹ T_get_min ts = length ts"
by (induction ts rule: T_get_min.induct) auto

(* get_min on a non-empty valid forest takes time logarithmic in the number
   of elements: the exact time length ts is bounded with size_mset_forest. *)
lemma T_get_min_bound:
  assumes "invar ts"
  assumes "ts≠[]"
  shows "T_get_min ts ≤ log 2 (size (mset_forest ts) + 1)"
proof -
  have 1: "T_get_min ts = length ts" using assms T_get_min_estimate by auto
  also note size_mset_forest[OF ‹invar ts›]
  finally show ?thesis .
qed

subsubsection ‹‹T_del_min››

(* Generate the timing function T_get_min_rest. *)
time_fun get_min_rest

(* On a non-empty forest the running time of get_min_rest equals its length. *)
lemma T_get_min_rest_estimate: "ts≠[] ⟹ T_get_min_rest ts = length ts"
  by (induction ts rule: T_get_min_rest.induct) auto

(* get_min_rest on a non-empty valid forest takes time logarithmic in the
   number of elements: the exact time length ts is bounded with
   size_mset_forest. *)
lemma T_get_min_rest_bound:
  assumes "invar ts"
  assumes "ts≠[]"
  shows "T_get_min_rest ts ≤ log 2 (size (mset_forest ts) + 1)"
proof -
  have 1: "T_get_min_rest ts = length ts" using assms T_get_min_rest_estimate by auto
  also note size_mset_forest[OF ‹invar ts›]
  finally show ?thesis .
qed

(* Generate the timing function T_del_min. *)
time_fun del_min

(* del_min on a non-empty valid forest with n elements takes time logarithmic
   in n.  The time is split into three parts, each bounded separately:
   get_min_rest, the reversal of the children ts1 of the removed tree, and
   the merge of the reversed children with the remaining forest ts2.  The
   element counts n1 of ts1 and n2 of ts2 satisfy n1 + n2 ≤ n. *)
lemma T_del_min_bound:
  fixes ts
  defines "n ≡ size (mset_forest ts)"
  assumes "invar ts" and "ts≠[]"
  shows "T_del_min ts ≤ 6 * log 2 (n+1) + 2"
proof -
  obtain r x ts1 ts2 where GM: "get_min_rest ts = (Node r x ts1, ts2)"
    by (metis surj_pair tree.exhaust_sel)

  have I1: "invar (rev ts1)" and I2: "invar ts2"
    using invar_get_min_rest[OF GM ‹ts≠[]› ‹invar ts›] invar_children
    by auto

  define n1 where "n1 = size (mset_forest ts1)"
  define n2 where "n2 = size (mset_forest ts2)"

  have "n1 ≤ n" "n1 + n2 ≤ n" unfolding n_def n1_def n2_def
    using mset_get_min_rest[OF GM ‹ts≠[]›]
    by (auto simp: mset_forest_def)

  have "T_del_min ts = real (T_get_min_rest ts) + real (T_itrev ts1 []) + real (T_merge (rev ts1) ts2)"
    unfolding T_del_min.simps GM T_itrev itrev_Nil
    by simp
  also have "T_get_min_rest ts ≤ log 2 (n+1)"
    using T_get_min_rest_bound[OF ‹invar ts› ‹ts≠[]›] unfolding n_def by simp
  also have "T_itrev ts1 [] ≤ 1 + log 2 (n1 + 1)"
    unfolding T_itrev n1_def using size_mset_forest[OF I1] by simp
  also have "T_merge (rev ts1) ts2 ≤ 4*log 2 (n1 + n2 + 1) + 1"
    unfolding n1_def n2_def using T_merge_bound[OF I1 I2] by (simp add: algebra_simps)
  finally have "T_del_min ts ≤ log 2 (n+1) + log 2 (n1 + 1) + 4*log 2 (real (n1 + n2) + 1) + 2"
    by (simp add: algebra_simps)
  also note ‹n1 + n2 ≤ n›
  also note ‹n1 ≤ n›
  finally show ?thesis by (simp add: algebra_simps)
qed

end