(* Theory Set2_Join_RBT *)

(* Author: Tobias Nipkow *)

section "Join-Based Implementation of Sets via RBTs"

theory Set2_Join_RBT
imports
  Set2_Join
  RBT_Set
begin

subsection "Code"

text ‹
Function ‹joinL› joins two trees (and an element).
Precondition: \<^prop>‹bheight l ≤ bheight r›.
Method:
Descend along the left spine of ‹r›
until you find a subtree with the same ‹bheight› as ‹l›,
then combine them into a new red node.
›
(* joinL l x r: join l, x and r, intended for bheight l ≤ bheight r.
   While bheight l < bheight r, descend into the left subtree of r; as soon
   as bheight l ≥ bheight r, put l, x and the current subtree under a new
   red node.  On the way back up, black ancestors are rebuilt with baliL
   (imported via RBT_Set) and red ancestors are rebuilt unchanged.
   The result may have a red root; join below paints it black. *)
fun joinL :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"joinL l x r =
  (if bheight l ≥ bheight r then R l x r
   else case r of
     B l' x' r' ⇒ baliL (joinL l x l') x' r' |
     R l' x' r' ⇒ R (joinL l x l') x' r')"

(* joinR l x r: mirror image of joinL, intended for bheight l ≥ bheight r.
   While bheight l > bheight r, descend into the right subtree of l; as soon
   as bheight l ≤ bheight r, put the current subtree, x and r under a new
   red node.  On the way back up, black ancestors are rebuilt with baliR
   (imported via RBT_Set) and red ancestors are rebuilt unchanged.
   The result may have a red root; join below paints it black. *)
fun joinR :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"joinR l x r =
  (if bheight l ≤ bheight r then R l x r
   else case l of
     B l' x' r' ⇒ baliR l' x' (joinR r' x r) |
     R l' x' r' ⇒ R l' x' (joinR r' x r))"

(* join l x r: the join operation used for the Set2_Join interpretation.
   The argument with the larger black height receives the other one via
   joinR (l higher) or joinL (r higher), and the root of that result is
   painted black.  Arguments of equal black height are simply placed under
   a new black node. *)
definition join :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"join l x r =
  (if bheight l > bheight r
   then paint Black (joinR l x r)
   else if bheight l < bheight r
   then paint Black (joinL l x r)
   else B l x r)"

(* joinL and joinR are recursive and guarded by an if, so their defining
   equations are kept out of the simp set; the proofs below unfold them
   explicitly, one step at a time, via joinL.simps[of l x r] and
   joinR.simps[of l x r]. *)
declare joinL.simps [simp del] joinR.simps [simp del]


subsection "Properties"

subsubsection "Color and height invariants"

(* Colour invariant for joinL: the result always satisfies invc2, and it
   satisfies invc when the black heights differ and r has a black root. *)
lemma invc2_joinL:
 "⟦ invc l; invc r; bheight l ≤ bheight r ⟧ ⟹
  invc2 (joinL l x r)
  ∧ (bheight l ≠ bheight r ∧ color r = Black ⟶ invc(joinL l x r))"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: invc_baliL invc2I joinL.simps[of l x r] split!: tree.splits if_splits)
qed

(* Colour invariant for joinR (mirror of invc2_joinL): the result always
   satisfies invc2, and it satisfies invc when the black heights differ and
   l has a black root. *)
lemma invc2_joinR:
  "⟦ invc l; invh l; invc r; invh r; bheight l ≥ bheight r ⟧ ⟹
  invc2 (joinR l x r)
  ∧ (bheight l ≠ bheight r ∧ color l = Black ⟶ invc(joinR l x r))"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: invc_baliR invc2I joinR.simps[of l x r] split!: tree.splits if_splits)
qed

(* With invh on both arguments, joinL keeps the black height of the
   higher argument r. *)
lemma bheight_joinL:
  "⟦ invh l; invh r; bheight l ≤ bheight r ⟧ ⟹ bheight (joinL l x r) = bheight r"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: bheight_baliL joinL.simps[of l x r] split!: tree.split)
qed

(* joinL preserves the black-height invariant invh. *)
lemma invh_joinL:
  "⟦ invh l;  invh r;  bheight l ≤ bheight r ⟧ ⟹ invh (joinL l x r)"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: invh_baliL bheight_joinL joinL.simps[of l x r] split!: tree.split color.split)
qed

(* With invh on both arguments, joinR keeps the black height of the
   higher argument l. *)
lemma bheight_joinR:
  "⟦ invh l;  invh r;  bheight l ≥ bheight r ⟧ ⟹ bheight (joinR l x r) = bheight l"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: bheight_baliR joinR.simps[of l x r] split!: tree.split)
qed

(* joinR preserves the black-height invariant invh. *)
lemma invh_joinR:
  "⟦ invh l; invh r; bheight l ≥ bheight r ⟧ ⟹ invh (joinR l x r)"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: invh_baliR bheight_joinR joinR.simps[of l x r]
        split!: tree.split color.split)
qed

text ‹All invariants in one:›

(* The conclusions of invc2_joinL, invh_joinL and bheight_joinL in a single
   statement, proved by one induction; used by rbt_join and inv_join. *)
lemma inv_joinL: "⟦ invc l; invc r; invh l; invh r; bheight l ≤ bheight r ⟧
 ⟹ invc2 (joinL l x r) ∧ (bheight l ≠ bheight r ∧ color r = Black ⟶  invc (joinL l x r))
     ∧ invh (joinL l x r) ∧ bheight (joinL l x r) = bheight r"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: inv_baliL invc2I joinL.simps[of l x r] split!: tree.splits if_splits)
qed

(* The conclusions of invc2_joinR, invh_joinR and bheight_joinR in a single
   statement, proved by one induction; used by rbt_join and inv_join. *)
lemma inv_joinR: "⟦ invc l; invc r; invh l; invh r; bheight l ≥ bheight r ⟧
 ⟹ invc2 (joinR l x r) ∧ (bheight l ≠ bheight r ∧ color l = Black ⟶  invc (joinR l x r))
     ∧ invh (joinR l x r) ∧ bheight (joinR l x r) = bheight l"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(auto simp: inv_baliR invc2I joinR.simps[of l x r] split!: tree.splits if_splits)
qed

(* unused *)
(* join applied to trees satisfying invc and invh yields a tree satisfying
   rbt (rbt_def unfolded in the proof). *)
lemma rbt_join: "⟦ invc l; invh l; invc r; invh r ⟧ ⟹ rbt(join l x r)"
by(simp add: inv_joinL inv_joinR invh_paint rbt_def color_paint_Black join_def)

text ‹To make sure the black height is not increased unnecessarily:›

(* Painting the root black raises the black height by at most one. *)
lemma bheight_paint_Black: "bheight(paint Black t) ≤ bheight t + 1"
by(cases t) auto

(* For rbt inputs, join exceeds the larger black height of its arguments
   by at most one. *)
lemma "⟦ rbt l; rbt r ⟧ ⟹ bheight(join l x r) ≤ max (bheight l) (bheight r) + 1"
using bheight_paint_Black[of "joinL l x r"] bheight_paint_Black[of "joinR l x r"]
  bheight_joinL[of l r x] bheight_joinR[of l r x]
by(auto simp: max_def rbt_def join_def)


subsubsection "Inorder properties"

text "Currently unused. Instead \<^const>‹set_tree› and \<^const>‹bst› properties are proved directly."

(* In-order traversal of joinL: l, then x, then r.
   Stated under the precondition bheight l ≤ bheight r. *)
lemma inorder_joinL: "bheight l ≤ bheight r ⟹ inorder(joinL l x r) = inorder l @ x # inorder r"
proof(induction l x r rule: joinL.induct)
  case (1 l x r)
  thus ?case by(auto simp: inorder_baliL joinL.simps[of l x r] split!: tree.splits color.splits)
qed

(* In-order traversal of joinR: l, then x, then r.  Unlike inorder_joinL,
   this is stated without a black-height precondition. *)
lemma inorder_joinR:
  "inorder(joinR l x r) = inorder l @ x # inorder r"
proof(induction l x r rule: joinR.induct)
  case (1 l x r)
  thus ?case by (force simp: inorder_baliR joinR.simps[of l x r] split!: tree.splits color.splits)
qed

(* In-order traversal of join: l, then x, then r.  Obtained from
   inorder_joinL, inorder_joinR and inorder_paint by splitting on the
   if-branches of join_def. *)
lemma "inorder(join l x r) = inorder l @ x # inorder r"
by(auto simp: inorder_joinL inorder_joinR inorder_paint join_def
      split!: tree.splits color.splits if_splits
      dest!: arg_cong[where f = inorder])


subsubsection "Set and bst properties"

(* baliL does not change the set of elements. *)
lemma set_baliL:
  "set_tree(baliL l a r) = set_tree l ∪ {a} ∪ set_tree r"
by(cases "(l,a,r)" rule: baliL.cases) (auto)

(* joinL contains exactly the elements of l, x and r
   (stated under bheight l ≤ bheight r). *)
lemma set_joinL:
  "bheight l ≤ bheight r ⟹ set_tree (joinL l x r) = set_tree l ∪ {x} ∪ set_tree r"
proof(induction l x r rule: joinL.induct)
  case (1 l x r)
  thus ?case by(auto simp: set_baliL joinL.simps[of l x r] split!: tree.splits color.splits)
qed

(* baliR does not change the set of elements. *)
lemma set_baliR:
  "set_tree(baliR l a r) = set_tree l ∪ {a} ∪ set_tree r"
by(cases "(l,a,r)" rule: baliR.cases) (auto)

(* joinR contains exactly the elements of l, x and r (no height precondition). *)
lemma set_joinR:
  "set_tree (joinR l x r) = set_tree l ∪ {x} ∪ set_tree r"
proof(induction l x r rule: joinR.induct)
  case (1 l x r)
  thus ?case by(force simp: set_baliR joinR.simps[of l x r] split!: tree.splits color.splits)
qed

(* Recolouring the root with paint does not change the set of elements. *)
lemma set_paint: "set_tree (paint c t) = set_tree t"
by (cases t) auto

(* join contains exactly the elements of l, x and r; this discharges the
   first goal of the Set2_Join interpretation. *)
lemma set_join: "set_tree (join l x r) = set_tree l ∪ {x} ∪ set_tree r"
by(simp add: set_joinL set_joinR set_paint join_def)

(* baliL preserves bst when a lies strictly between the elements of l and
   the elements of r. *)
lemma bst_baliL:
  "⟦bst l; bst r; ∀x∈set_tree l. x < a; ∀x∈set_tree r. a < x⟧
   ⟹ bst (baliL l a r)"
by(cases "(l,a,r)" rule: baliL.cases) (auto simp: ball_Un)

(* baliR preserves bst when a lies strictly between the elements of l and
   the elements of r. *)
lemma bst_baliR:
  "⟦bst l; bst r; ∀x∈set_tree l. x < a; ∀x∈set_tree r. a < x⟧
   ⟹ bst (baliR l a r)"
by(cases "(l,a,r)" rule: baliR.cases) (auto simp: ball_Un)

(* joinL preserves bst.  The ordering premise is packaged as
   bst (Node l (a, n) r); n is presumably just the node colour and plays
   no role in the ordering. *)
lemma bst_joinL:
  "⟦bst (Node l (a, n) r); bheight l ≤ bheight r⟧
  ⟹ bst (joinL l a r)"
proof(induction l a r rule: joinL.induct)
  case (1 l a r)
  thus ?case
    by(auto simp: set_baliL joinL.simps[of l a r] set_joinL ball_Un intro!: bst_baliL
        split!: tree.splits color.splits)
qed

(* joinR preserves bst when a lies strictly between the elements of l and
   the elements of r. *)
lemma bst_joinR:
  "⟦bst l; bst r; ∀x∈set_tree l. x < a; ∀y∈set_tree r. a < y ⟧
  ⟹ bst (joinR l a r)"
proof(induction l a r rule: joinR.induct)
  case (1 l a r)
  thus ?case
    by(auto simp: set_baliR joinR.simps[of l a r] set_joinR ball_Un intro!: bst_baliR
        split!: tree.splits color.splits)
qed

(* Recolouring the root with paint does not affect bst. *)
lemma bst_paint: "bst (paint c t) = bst t"
by(cases t) auto

(* join preserves bst: if l, a and r form a bst node (with any colour n),
   then join l a r is a bst.  Used for the second interpretation goal. *)
lemma bst_join:
  "bst (Node l (a, n) r) ⟹ bst (join l a r)"
by(auto simp: bst_paint bst_joinL bst_joinR join_def)

(* join establishes invc and invh, i.e. the invariant λt. invc t ∧ invh t
   used in the Set2_Join interpretation. *)
lemma inv_join: "⟦ invc l; invh l; invc r; invh r ⟧ ⟹ invc(join l x r) ∧ invh(join l x r)"
by (simp add: inv_joinL inv_joinR invh_paint join_def)

subsubsection "Interpretation of \<^locale>‹Set2_Join› with Red-Black Tree"

(* Instantiate the join-based set locale Set2_Join with the red-black join
   defined above and the invariant invc ∧ invh.  The defines clause
   introduces named constants (insert_rbt, delete_rbt, split_rbt, join2_rbt,
   split_min_rbt) for the derived operations.  Goals 1, 2 and 4 are
   discharged by set_join, bst_join and inv_join; goals 3 and 5 by simp. *)
global_interpretation RBT: Set2_Join
where join = join and inv = "λt. invc t ∧ invh t"
defines insert_rbt = RBT.insert and delete_rbt = RBT.delete and split_rbt = RBT.split
and join2_rbt = RBT.join2 and split_min_rbt = RBT.split_min
proof (standard, goal_cases)
  case 1 show ?case by (rule set_join)
next
  case 2 thus ?case by (simp add: bst_join)
next
  case 3 show ?case by simp
next
  case 4 thus ?case by (simp add: inv_join)
next
  case 5 thus ?case by simp
qed

text ‹The invariant does not guarantee that the root node is black. This is not required
to guarantee that the height is logarithmic in the size --- Exercise.›

end