(* Theory Selection *)

(*
  File:    Data_Structures/Selection.thy
  Author:  Manuel Eberl, TU München
*)
section ‹The Median-of-Medians Selection Algorithm›
theory Selection
  imports Complex_Main Time_Funs Sorting
begin

text ‹
  Note that there is significant overlap between this theory (which is intended mostly for the
  Functional Data Structures book) and the Median-of-Medians AFP entry.
›

subsection ‹Auxiliary material›

text ‹‹replicate› with a numeral count unfolds one step, exposing its first element.›
lemma replicate_numeral: "replicate (numeral n) x = x # replicate (pred_numeral n) x"
  by (simp add: numeral_eq_Suc)

text ‹Insertion sort ‹insort› (from the imported theory ‹Sorting›) agrees with ‹sort›.›
lemma insort_correct: "insort xs = sort xs"
  using sorted_insort mset_insort by (metis properties_for_sort)

text ‹Summing ‹n› copies of ‹x› gives ‹n * x›.›
lemma sum_list_replicate [simp]: "sum_list (replicate n x) = n * x"
  by (induction n) auto

text ‹The multiset of a concatenation is the sum of the multisets of its parts.›
lemma mset_concat: "mset (concat xss) = sum_list (map mset xss)"
  by (induction xss) simp_all

text ‹Interaction of multiset sums of lists with ‹set_mset› and filtering.›
lemma set_mset_sum_list [simp]: "set_mset (sum_list xs) = (⋃x∈set xs. set_mset x)"
  by (induction xs) auto

lemma filter_mset_image_mset:
  "filter_mset P (image_mset f A) = image_mset f (filter_mset (λx. P (f x)) A)"
  by (induction A) auto

lemma filter_mset_sum_list: "filter_mset P (sum_list xs) = sum_list (map (filter_mset P) xs)"
  by (induction xs) simp_all

text ‹Monotonicity of multiset sums and of filtering w.r.t. the sub-multiset order ‹⊆#›.›
lemma sum_mset_mset_mono:
  assumes "(⋀x. x ∈# A ⟹ f x ⊆# g x)"
  shows   "(∑x∈#A. f x) ⊆# (∑x∈#A. g x)"
  using assms by (induction A) (auto intro!: subset_mset.add_mono)

lemma mset_filter_mono:
  assumes "A ⊆# B" "⋀x. x ∈# A ⟹ P x ⟹ Q x"
  shows   "filter_mset P A ⊆# filter_mset Q B"
  by (rule mset_subset_eqI) (insert assms, auto simp: mset_subset_eq_count count_eq_zero_iff)

text ‹The size of a sum of multisets is the sum of their sizes.›
lemma size_mset_sum_mset_distrib: "size (sum_mset A :: 'a multiset) = sum_mset (image_mset size A)"
  by (induction A) auto

text ‹Pointwise ‹≤› lifts to sums over a multiset.›
lemma sum_mset_mono:
  assumes "⋀x. x ∈# A ⟹ f x ≤ (g x :: 'a :: {ordered_ab_semigroup_add,comm_monoid_add})"
  shows   "(∑x∈#A. f x) ≤ (∑x∈#A. g x)"
  using assms by (induction A) (auto intro!: add_mono)

lemma filter_mset_is_empty_iff: "filter_mset P A = {#} ⟷ (∀x. x ∈# A ⟶ ¬P x)"
  by (auto simp: multiset_eq_iff count_eq_zero_iff)

text ‹Basic facts about ‹sort›: emptiness, dependence only on the multiset, and its head.›
lemma sort_eq_Nil_iff [simp]: "sort xs = [] ⟷ xs = []"
  by (metis set_empty set_sort)

lemma sort_mset_cong: "mset xs = mset ys ⟹ sort xs = sort ys"
  by (metis sorted_list_of_multiset_mset)

lemma Min_set_sorted: "sorted xs ⟹ xs ≠ [] ⟹ Min (set xs) = hd xs"
  by (cases xs; force intro: Min_insert2)

lemma hd_sort:
  fixes xs :: "'a :: linorder list"
  shows "xs ≠ [] ⟹ hd (sort xs) = Min (set xs)"
  by (subst Min_set_sorted [symmetric]) auto

lemma length_filter_conv_size_filter_mset: "length (filter P xs) = size (filter_mset P (mset xs))"
  by (induction xs) auto

text ‹
  In a sorted list, the elements strictly smaller than the ‹i›-th one all occur
  (with multiplicity) among the first ‹i› elements.
›
lemma sorted_filter_less_subset_take:
  assumes "sorted xs" and "i < length xs"
  shows   "{#x ∈# mset xs. x < xs ! i#} ⊆# mset (take i xs)"
  using assms
proof (induction xs arbitrary: i rule: list.induct)
  case (Cons x xs i)
  show ?case
  proof (cases i)
    case 0
    thus ?thesis using Cons.prems by (auto simp: filter_mset_is_empty_iff)
  next
    case (Suc i')
    have "{#y ∈# mset (x # xs). y < (x # xs) ! i#} ⊆# add_mset x {#y ∈# mset xs. y < xs ! i'#}"
      using Suc Cons.prems by (auto)
    also have "… ⊆# add_mset x (mset (take i' xs))"
      unfolding mset_subset_eq_add_mset_cancel using Cons.prems Suc
      by (intro Cons.IH) (auto)
    also have "… = mset (take i (x # xs))" by (simp add: Suc)
    finally show ?thesis .
  qed
qed auto

text ‹
  Dually, in a sorted list the elements strictly greater than the ‹i›-th one all occur
  (with multiplicity) after position ‹i›.
›
lemma sorted_filter_greater_subset_drop:
  assumes "sorted xs" and "i < length xs"
  shows   "{#x ∈# mset xs. x > xs ! i#} ⊆# mset (drop (Suc i) xs)"
  using assms
proof (induction xs arbitrary: i rule: list.induct)
  case (Cons x xs i)
  show ?case
  proof (cases i)
    case 0
    thus ?thesis by (auto simp: sorted_append filter_mset_is_empty_iff)
  next
    case (Suc i')
    have "{#y ∈# mset (x # xs). y > (x # xs) ! i#} ⊆# {#y ∈# mset xs. y > xs ! i'#}"
      using Suc Cons.prems by (auto simp: set_conv_nth)
    also have "… ⊆# mset (drop (Suc i') xs)"
      using Cons.prems Suc by (intro Cons.IH) (auto)
    also have "… = mset (drop (Suc i) (x # xs))" by (simp add: Suc)
    finally show ?thesis .
  qed
qed auto


subsection ‹Chopping a list into equally-sized bits›

text ‹
  ‹chop n xs› splits ‹xs› into consecutive pieces ‹take n …›; every piece has length ‹n› except
  possibly the last one. Chopping with ‹n = 0› yields ‹[]›. The defining equations are removed
  from the simp set; the lemmas below (‹chop_0›, ‹chop_Nil›, ‹chop_reduce›) are used instead.
›
fun chop :: "nat ⇒ 'a list ⇒ 'a list list" where
  "chop 0 _  = []"
| "chop _ [] = []"
| "chop n xs = take n xs # chop n (drop n xs)"

lemmas [simp del] = chop.simps

text ‹
  This is an alternative induction rule for @{const chop}, which is often nicer to use.
›
lemma chop_induct' [case_names trivial reduce]:
  assumes "⋀n xs. n = 0 ∨ xs = [] ⟹ P n xs"
  assumes "⋀n xs. n > 0 ⟹ xs ≠ [] ⟹ P n (drop n xs) ⟹ P n xs"
  shows   "P n xs"
  using assms
proof induction_schema
  (* The list argument strictly shrinks in the "reduce" case because n > 0 and xs ≠ []. *)
  show "wf (measure (length ∘ snd))"
    by auto
qed (blast | simp)+

text ‹Basic equations for @{const chop}, used in place of its (non-simp) defining equations.›
lemma chop_eq_Nil_iff [simp]: "chop n xs = [] ⟷ n = 0 ∨ xs = []"
  by (induction n xs rule: chop.induct; subst chop.simps) auto

lemma chop_0 [simp]: "chop 0 xs = []"
  by (simp add: chop.simps)

lemma chop_Nil [simp]: "chop n [] = []"
  by (cases n) (auto simp: chop.simps)

lemma chop_reduce: "n > 0 ⟹ xs ≠ [] ⟹ chop n xs = take n xs # chop n (drop n xs)"
  by (cases n; cases xs) (auto simp: chop.simps)

lemma concat_chop [simp]: "n > 0 ⟹ concat (chop n xs) = xs"
  by (induction n xs rule: chop.induct; subst chop.simps) auto

text ‹Every piece produced by @{const chop} is non-empty and of length at most ‹n›.›
lemma chop_elem_not_Nil [dest]: "ys ∈ set (chop n xs) ⟹ ys ≠ []"
  by (induction n xs rule: chop.induct; subst (asm) chop.simps)
     (auto simp: eq_commute[of "[]"] split: if_splits)

lemma length_chop_part_le: "ys ∈ set (chop n xs) ⟹ length ys ≤ n"
  by (induction n xs rule: chop.induct; subst (asm) chop.simps) (auto split: if_splits)

text ‹Chopping into pieces of size ‹n > 0› yields ‹⌈length xs / n⌉› pieces.›
lemma length_chop:
  assumes "n > 0"
  shows   "length (chop n xs) = nat ⌈length xs / n⌉"
proof -
  (* The number of pieces p satisfies  length xs / n ≤ p < length xs / n + 1. *)
  from ‹n > 0› have "real n * length (chop n xs) ≥ length xs"
    by (induction n xs rule: chop.induct; subst chop.simps) (auto simp: field_simps)
  moreover from ‹n > 0› have "real n * length (chop n xs) < length xs + n"
    by (induction n xs rule: chop.induct; subst chop.simps)
       (auto simp: field_simps split: nat_diff_split_asm)+
  ultimately have "length (chop n xs) ≥ length xs / n" and "length (chop n xs) < length xs / n + 1"
    using assms by (auto simp: field_simps)
  thus ?thesis by linarith
qed

text ‹The pieces of @{const chop} jointly contain exactly the elements of the original list.›
lemma sum_msets_chop: "n > 0 ⟹ (∑ys←chop n xs. mset ys) = mset xs"
  by (subst mset_concat [symmetric]) simp_all

lemma UN_sets_chop: "n > 0 ⟹ (⋃ys∈set (chop n xs). set ys) = set xs"
  by (simp only: set_concat [symmetric] concat_chop)

text ‹If the first list's length is a multiple of ‹d›, chopping distributes over ‹@›.›
lemma chop_append: "d dvd length xs ⟹ chop d (xs @ ys) = chop d xs @ chop d ys"
  by (induction d xs rule: chop_induct') (auto simp: chop_reduce dvd_imp_le)

lemma chop_replicate [simp]: "d > 0 ⟹ chop d (replicate d xs) = [replicate d xs]"
  by (subst chop_reduce) auto

text ‹
  Chopping ‹n› copies of ‹x› into pieces of size ‹d›, where ‹d› divides ‹n›, gives ‹n div d›
  pieces, each consisting of ‹d› copies of ‹x›.
›
lemma chop_replicate_dvd [simp]:
  assumes "d dvd n"
  shows   "chop d (replicate n x) = replicate (n div d) (replicate d x)"
proof (cases "d = 0")
  case False
  from assms obtain k where k: "n = d * k"
    by blast
  have "chop d (replicate (d * k) x) = replicate k (replicate d x)"
    using False by (induction k) (auto simp: replicate_add chop_append)
  thus ?thesis using False by (simp add: k)
qed auto

text ‹Chopping a concatenation of lists that all have length ‹d > 0› recovers those lists.›
lemma chop_concat:
  assumes "∀xs∈set xss. length xs = d" and "d > 0"
  shows   "chop d (concat xss) = xss"
  using assms
proof (induction xss)
  case (Cons xs xss)
  have "chop d (concat (xs # xss)) = chop d (xs @ concat xss)"
    by simp
  also have "… = chop d xs @ chop d (concat xss)"
    using Cons.prems by (intro chop_append) auto
  also have "chop d xs = [xs]"
    using Cons.prems by (subst chop_reduce) auto
  also have "chop d (concat xss) = xss"
    using Cons.prems by (intro Cons.IH) auto
  finally show ?case by simp
qed auto


subsection ‹Selection›

text ‹
  ‹select k xs› is the element at position ‹k› (counting from 0) of the sorted list, i.e. the
  ‹k›-th smallest element of ‹xs› counted with multiplicity. It is only meaningful for
  ‹k < length xs›.
›
definition select :: "nat ⇒ ('a :: linorder) list ⇒ 'a" where
  "select k xs = sort xs ! k"

lemma select_0: "xs ≠ [] ⟹ select 0 xs = Min (set xs)"
  by (simp add: hd_sort select_def flip: hd_conv_nth)

text ‹‹select› depends only on the multiset of the list.›
lemma select_mset_cong: "mset xs = mset ys ⟹ select k xs = select k ys"
  using sort_mset_cong[of xs ys] unfolding select_def by auto

lemma select_in_set [intro,simp]:
  assumes "k < length xs"
  shows   "select k xs ∈ set xs"
proof -
  from assms have "sort xs ! k ∈ set (sort xs)" by (intro nth_mem) auto
  also have "set (sort xs) = set xs" by simp
  finally show ?thesis by (simp add: select_def)
qed

text ‹
  At most ‹n› elements of ‹xs› are strictly smaller than ‹select n xs›, and fewer than
  ‹length xs - n› are strictly greater.
›
lemma
  assumes "n < length xs"
  shows   size_less_than_select: "size {#y ∈# mset xs. y < select n xs#} ≤ n"
    and   size_greater_than_select: "size {#y ∈# mset xs. y > select n xs#} < length xs - n"
proof -
  have "size {#y ∈# mset (sort xs). y < select n xs#} ≤ size (mset (take n (sort xs)))"
    unfolding select_def using assms
    by (intro size_mset_mono sorted_filter_less_subset_take) auto
  thus "size {#y ∈# mset xs. y < select n xs#} ≤ n"
    by simp
  have "size {#y ∈# mset (sort xs). y > select n xs#} ≤ size (mset (drop (Suc n) (sort xs)))"
    unfolding select_def using assms
    by (intro size_mset_mono sorted_filter_greater_subset_drop) auto
  thus "size {#y ∈# mset xs. y > select n xs#} < length xs - n"
    using assms by simp
qed


subsection ‹The designated median of a list›

text ‹
  The designated median is the lower median: the element at position ‹(length xs - 1) div 2›
  of the sorted list.
›
definition median where "median xs = select ((length xs - 1) div 2) xs"

lemma median_in_set [intro, simp]:
  assumes "xs ≠ []"
  shows   "median xs ∈ set xs"
proof -
  from assms have "length xs > 0" by auto
  hence "(length xs - 1) div 2 < length xs" by linarith
  thus ?thesis by (simp add: median_def)
qed

text ‹At most ‹(length xs - 1) div 2› elements are strictly smaller than the median.›
lemma size_less_than_median: "size {#y ∈# mset xs. y < median xs#} ≤ (length xs - 1) div 2"
proof (cases "xs = []")
  case False
  hence "length xs > 0"
    by auto
  hence less: "(length xs - 1) div 2 < length xs"
    by linarith
  show "size {#y ∈# mset xs. y < median xs#} ≤ (length xs - 1) div 2"
    using size_less_than_select[OF less] by (simp add: median_def)
qed auto

text ‹At most ‹length xs div 2› elements are strictly greater than the median.›
lemma size_greater_than_median: "size {#y ∈# mset xs. y > median xs#} ≤ length xs div 2"
proof (cases "xs = []")
  case False
  hence "length xs > 0"
    by auto
  hence less: "(length xs - 1) div 2 < length xs"
    by linarith
  have "size {#y ∈# mset xs. y > median xs#} < length xs - (length xs - 1) div 2"
    using size_greater_than_select[OF less] by (simp add: median_def)
  also have "… = length xs div 2 + 1"
    using ‹length xs > 0› by linarith
  finally show "size {#y ∈# mset xs. y > median xs#} ≤ length xs div 2"
    by simp
qed auto

lemmas median_props = size_less_than_median size_greater_than_median


subsection ‹A recurrence for selection›

text ‹
  ‹partition3 x xs› splits ‹xs› into the elements smaller than, equal to, and greater than ‹x›,
  each part keeping the original relative order (it is built from ‹filter›).
›
definition partition3 :: "'a ⇒ 'a :: linorder list ⇒ 'a list × 'a list × 'a list" where
  "partition3 x xs = (filter (λy. y < x) xs, filter (λy. y = x) xs, filter (λy. y > x) xs)"

text ‹Single-pass code equations for @{const partition3}.›
lemma partition3_code [code]:
  "partition3 x [] = ([], [], [])"
  "partition3 x (y # ys) =
     (case partition3 x ys of (ls, es, gs) ⇒
        if y < x then (y # ls, es, gs) else if x = y then (ls, y # es, gs) else (ls, es, y # gs))"
  by (auto simp: partition3_def)

lemma length_partition3:
  assumes "partition3 x xs = (ls, es, gs)"
  shows   "length xs = length ls + length es + length gs"
  using assms by (induction xs arbitrary: ls es gs)
                 (auto simp: partition3_code split: if_splits prod.splits)

text ‹
  If every element of the first list is ‹≤› every element of the second, sorting and
  selection distribute over ‹@›.
›
lemma sort_append:
  assumes "∀x∈set xs. ∀y∈set ys. x ≤ y"
  shows   "sort (xs @ ys) = sort xs @ sort ys"
  using assms by (intro properties_for_sort) (auto simp: sorted_append)

lemma select_append:
  assumes "∀y∈set ys. ∀z∈set zs. y ≤ z"
  shows   "k < length ys ⟹ select k (ys @ zs) = select k ys"
    and   "k ∈ {length ys..<length ys + length zs} ⟹
             select k (ys @ zs) = select (k - length ys) zs"
  using assms by (simp_all add: select_def sort_append nth_append)

lemma select_append':
  assumes "∀y∈set ys. ∀z∈set zs. y ≤ z" and "k < length ys + length zs"
  shows   "select k (ys @ zs) = (if k < length ys then select k ys else select (k - length ys) zs)"
  using assms by (auto intro!: select_append)

text ‹
  The key recurrence: for any pivot ‹x›, selecting the ‹k›-th element of ‹xs› reduces to
  selecting among the elements below ‹x›, returning ‹x› itself, or selecting among the
  elements above ‹x›, depending on which range ‹k› falls into.
›
theorem select_rec_partition:
  assumes "k < length xs"
  shows "select k xs = (
           let (ls, es, gs) = partition3 x xs
           in
             if k < length ls then select k ls
             else if k < length ls + length es then x
             else select (k - length ls - length es) gs
          )" (is "_ = ?rhs")
proof -
  define ls es gs where "ls = filter (λy. y < x) xs" and "es = filter (λy. y = x) xs"
                    and "gs = filter (λy. y > x) xs"
  define nl ne where [simp]: "nl = length ls" "ne = length es"
  have mset_eq: "mset xs = mset ls + mset es + mset gs"
    unfolding ls_def es_def gs_def by (induction xs) auto
  have length_eq: "length xs = length ls + length es + length gs"
    unfolding ls_def es_def gs_def
    using [[simp_depth_limit = 1]] by (induction xs) auto
  (* every element of es equals x, hence so does every selection from it *)
  have [simp]: "select i es = x" if "i < length es" for i
  proof -
    have "select i es ∈ set (sort es)" unfolding select_def
      using that by (intro nth_mem) auto
    thus ?thesis
      by (auto simp: es_def)
  qed

  have "select k xs = select k (ls @ (es @ gs))"
    by (intro select_mset_cong) (simp_all add: mset_eq)
  also have "… = (if k < nl then select k ls else select (k - nl) (es @ gs))"
    unfolding nl_ne_def using assms
    by (intro select_append') (auto simp: ls_def es_def gs_def length_eq)
  also have "… = (if k < nl then select k ls else if k < nl + ne then x
                    else select (k - nl - ne) gs)"
  proof (rule if_cong)
    assume "¬k < nl"
    have "select (k - nl) (es @ gs) =
                 (if k - nl < ne then select (k - nl) es else select (k - nl - ne) gs)"
      unfolding nl_ne_def using assms ‹¬k < nl›
      by (intro select_append') (auto simp: ls_def es_def gs_def length_eq)
    also have "… = (if k < nl + ne then x else select (k - nl - ne) gs)"
      using ‹¬k < nl› by auto
    finally show "select (k - nl) (es @ gs) = …" .
  qed simp_all
  also have "… = ?rhs"
    by (simp add: partition3_def ls_def es_def gs_def)
  finally show ?thesis .
qed


subsection ‹The size of the lists in the recursive calls›

text ‹
  We now derive an upper bound for the number of elements of a list that are smaller
  (resp. bigger) than the median of medians with chopping size 5. To avoid having to do the
  same proof twice, we do it generically for an operation ‹≺› that we will later instantiate
  with either ‹<› or ‹>›.
›

context
  fixes xs :: "'a :: linorder list"
  fixes M defines "M ≡ median (map median (chop 5 xs))"
begin

text ‹
  Throughout this context, ‹M› is the median of the medians of the groups ‹chop 5 xs›.
›

lemma size_median_of_medians_aux:
  fixes R :: "'a :: linorder ⇒ 'a ⇒ bool" (infix "≺" 50)
  assumes R: "R ∈ {(<), (>)}"
  shows "size {#y ∈# mset xs. y ≺ M#} ≤ nat ⌈0.7 * length xs + 3⌉"
proof -
  define n and m where [simp]: "n = length xs" and "m = length (chop 5 xs)"

  text ‹We split the multiset of all the chopped-up groups into those groups whose median
        is ‹≺ M› and the rest.›
  define Y_small ("Y⇩≺") where "Y⇩≺ = filter_mset (λys. median ys ≺ M) (mset (chop 5 xs))"
  define Y_big ("Y⇩≽") where "Y⇩≽ = filter_mset (λys. ¬(median ys ≺ M)) (mset (chop 5 xs))"
  have "m = size (mset (chop 5 xs))" by (simp add: m_def)
  also have "mset (chop 5 xs) = Y⇩≺ + Y⇩≽" unfolding Y_small_def Y_big_def
    by (rule multiset_partition)
  finally have m_eq: "m = size Y⇩≺ + size Y⇩≽" by simp

  text ‹At most half of the groups have a median that is ‹≺› the median of medians:›
  have "size Y⇩≺ = size (image_mset median Y⇩≺)" by simp
  also have "image_mset median Y⇩≺ = {#y ∈# mset (map median (chop 5 xs)). y ≺ M#}"
    unfolding Y_small_def by (subst filter_mset_image_mset [symmetric]) simp_all
  also have "size … ≤ (length (map median (chop 5 xs))) div 2"
    unfolding M_def using median_props[of "map median (chop 5 xs)"] R by auto
  also have "… = m div 2" by (simp add: m_def)
  finally have size_Y_small: "size Y⇩≺ ≤ m div 2" .

  text ‹We estimate the number of elements ‹≺ M› by grouping them into elements
      coming from @{term "Y⇩≺"} and elements coming from @{term "Y⇩≽"}:›
  have "{#y ∈# mset xs. y ≺ M#} = {#y ∈# (∑ys←chop 5 xs. mset ys). y ≺ M#}"
    by (subst sum_msets_chop) simp_all
  also have "… = (∑ys←chop 5 xs. {#y ∈# mset ys. y ≺ M#})"
    by (subst filter_mset_sum_list) (simp add: o_def)
  also have "… = (∑ys∈#mset (chop 5 xs). {#y ∈# mset ys. y ≺ M#})"
    by (subst sum_mset_sum_list [symmetric]) simp_all
  also have "mset (chop 5 xs) = Y⇩≺ + Y⇩≽"
    by (simp add: Y_small_def Y_big_def not_le)
  also have "(∑ys∈#…. {#y ∈# mset ys. y ≺ M#}) =
               (∑ys∈#Y⇩≺. {#y ∈# mset ys. y ≺ M#}) + (∑ys∈#Y⇩≽. {#y ∈# mset ys. y ≺ M#})"
    by simp

  text ‹Next, we overapproximate the elements contributed by @{term "Y⇩≺"}: instead of those
        elements that are ‹≺ M›, we take ‹all› the elements of each group.
        For the elements contributed by @{term "Y⇩≽"}, we overapproximate by taking all those
        that are ‹≺› their own median instead of only those that are ‹≺ M›.›
  also have "… ⊆# (∑ys∈#Y⇩≺. mset ys) + (∑ys∈#Y⇩≽. {#y ∈# mset ys. y ≺ median ys#})"
    using R
    by (intro subset_mset.add_mono sum_mset_mset_mono mset_filter_mono) (auto simp: Y_big_def)
  finally have "size {# y ∈# mset xs. y ≺ M#} ≤ size …"
    by (rule size_mset_mono)
  hence "size {# y ∈# mset xs. y ≺ M#} ≤
           (∑ys∈#Y⇩≺. length ys) + (∑ys∈#Y⇩≽. size {#y ∈# mset ys. y ≺ median ys#})"
    by (simp add: size_mset_sum_mset_distrib multiset.map_comp o_def)

  text ‹Next, we further overapproximate the first sum by noting that each group has
        at most size 5.›
  also have "(∑ys∈#Y⇩≺. length ys) ≤ (∑ys∈#Y⇩≺. 5)"
    by (intro sum_mset_mono) (auto simp: Y_small_def length_chop_part_le)
  also have "… = 5 * size Y⇩≺" by simp

  text ‹Next, we note that each group in @{term "Y⇩≽"} can have at most 2 elements that are
        ‹≺› its own median.›
  also have "(∑ys∈#Y⇩≽. size {#y ∈# mset ys. y ≺ median ys#}) ≤
               (∑ys∈#Y⇩≽. length ys div 2)"
  proof (intro sum_mset_mono, goal_cases)
    fix ys assume "ys ∈# Y⇩≽"
    hence "ys ≠ []"
      by (auto simp: Y_big_def)
    thus "size {#y ∈# mset ys. y ≺ median ys#} ≤ length ys div 2"
      using R median_props[of ys] by auto
  qed
  also have "… ≤ (∑ys∈#Y⇩≽. 2)"
    by (intro sum_mset_mono div_le_mono diff_le_mono)
       (auto simp: Y_big_def dest: length_chop_part_le)
  also have "… = 2 * size Y⇩≽" by simp

  text ‹Simplifying gives us the main result. Note that ‹m = ⌈n / 5⌉› by @{thm [source] length_chop}.›
  also have "5 * size Y⇩≺ + 2 * size Y⇩≽ = 2 * m + 3 * size Y⇩≺"
    by (simp add: m_eq)
  also have "… ≤ 3.5 * m"
    using ‹size Y⇩≺ ≤ m div 2› by linarith
  also have "… = 3.5 * ⌈n / 5⌉"
    by (simp add: m_def length_chop)
  also have "… ≤ 0.7 * n + 3.5"
    by linarith
  finally have "size {#y ∈# mset xs. y ≺ M#} ≤ 0.7 * n + 3.5"
    by simp
  thus "size {#y ∈# mset xs. y ≺ M#} ≤ nat ⌈0.7 * n + 3⌉"
    by linarith
qed

text ‹Instantiating ‹≺› with ‹<› and ‹>›, respectively:›
lemma size_less_than_median_of_medians:
  "size {#y ∈# mset xs. y < M#} ≤ nat ⌈0.7 * length xs + 3⌉"
  using size_median_of_medians_aux[of "(<)"] by simp

lemma size_greater_than_median_of_medians:
  "size {#y ∈# mset xs. y > M#} ≤ nat ⌈0.7 * length xs + 3⌉"
  using size_median_of_medians_aux[of "(>)"] by simp

end


subsection ‹Efficient algorithm›

text ‹
  We handle the base cases and computing the median for the chopped-up sublists of size 5
  using the naive selection algorithm where we sort the list using insertion sort.
›
text ‹Naive selection: sort with insertion sort ‹insort›, then index into the result.›
definition slow_select where
  "slow_select k xs = insort xs ! k"

text ‹Naive median, using the same index convention as @{const median}.›
definition slow_median where
  "slow_median xs = slow_select ((length xs - 1) div 2) xs"

lemma slow_select_correct: "slow_select k xs = select k xs"
  by (simp add: slow_select_def select_def insort_correct)

lemma slow_median_correct: "slow_median xs = median xs"
  by (simp add: median_def slow_median_def slow_select_correct)

text ‹
  The definition of the selection algorithm is complicated somewhat by the fact that its
  termination is contingent on its correctness: if the first recursive call were to return an
  element for ‹x› that is e.g. smaller than all list elements, the algorithm would not terminate.

  Therefore, we first prove partial correctness, then termination, and then combine the two
  to obtain total correctness.
›
text ‹
  Lists of length at most 20 are handled by @{const slow_select}. Otherwise the list is chopped
  into groups of 5, and the pivot ‹M› is the median of the group medians, computed recursively.
  The index ‹((length xs + 4) div 5 - 1) div 2› is the median position among the
  ‹(length xs + 4) div 5 = ⌈length xs / 5⌉› group medians. The recursion then continues on the
  relevant part of ‹partition3 M xs›.
›
function mom_select where
  "mom_select k xs = (
     if length xs ≤ 20 then
       slow_select k xs
     else
       let M = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs));
           (ls, es, gs) = partition3 M xs
       in
         if k < length ls then mom_select k ls
         else if k < length ls + length es then M
         else mom_select (k - length ls - length es) gs
      )"
  by auto

text ‹
  If @{const "mom_select"} terminates, it agrees with @{const select}:
›
lemma mom_select_correct_aux:
  assumes "mom_select_dom (k, xs)" and "k < length xs"
  shows   "mom_select k xs = select k xs"
  using assms
proof (induction rule: mom_select.pinduct)
  case (1 k xs)
  show "mom_select k xs = select k xs"
  proof (cases "length xs ≤ 20")
    case True
    thus "mom_select k xs = select k xs" using "1.prems" "1.hyps"
      by (subst mom_select.psimps) (auto simp: select_def slow_select_correct)
  next
    case False
    define x where
      "x = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
    define ls es gs where "ls = filter (λy. y < x) xs" and "es = filter (λy. y = x) xs"
                      and "gs = filter (λy. y > x) xs"
    define nl ne where "nl = length ls" and "ne = length es"
    note defs = nl_def ne_def x_def ls_def es_def gs_def
    have tw: "(ls, es, gs) = partition3 x xs"
      unfolding partition3_def defs One_nat_def ..
    have length_eq: "length xs = nl + ne + length gs"
      unfolding nl_def ne_def ls_def es_def gs_def
      using [[simp_depth_limit = 1]] by (induction xs) auto
    (* Induction hypotheses for the recursive calls on ls and gs. *)
    note IH = "1.IH"(2,3)[OF False x_def tw refl refl]

    (* Unfold one step of mom_select, apply the IHs, and fold back with the recurrence. *)
    have "mom_select k xs = (if k < nl then mom_select k ls else if k < nl + ne then x
                                else mom_select (k - nl - ne) gs)" using "1.hyps" False
      by (subst mom_select.psimps) (simp_all add: partition3_def flip: defs One_nat_def)
    also have "… = (if k < nl then select k ls else if k < nl + ne then x
                       else select (k - nl - ne) gs)"
      using IH length_eq "1.prems" by (simp add: ls_def es_def gs_def nl_def ne_def)
    also have "… = select k xs" using ‹k < length xs›
      by (subst (3) select_rec_partition[of _ _ x]) (simp_all add: nl_def ne_def flip: tw)
    finally show "mom_select k xs = select k xs" .
  qed
qed

text ‹
  @{const mom_select} indeed terminates for all inputs:
›
lemma mom_select_termination: "All mom_select_dom"
proof (relation "measure (length ∘ snd)"; (safe)?)
  (* Recursive call on the group medians: there are only ⌈length xs / 5⌉ < length xs of them. *)
  fix k :: nat and xs :: "'a list"
  assume "¬ length xs ≤ 20"
  thus "((((length xs + 4) div 5 - 1) div 2, map slow_median (chop 5 xs)), k, xs)
           ∈ measure (length ∘ snd)"
    by (auto simp: length_chop nat_less_iff ceiling_less_iff)
next
  (* Recursive calls on ls and gs: both are shorter than xs because the pivot x occurs in xs
     but in neither of them. *)
  fix k :: nat and xs ls es gs :: "'a list"
  define x where "x = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
  assume A: "¬ length xs ≤ 20"
            "(ls, es, gs) = partition3 x xs"
            "mom_select_dom (((length xs + 4) div 5 - 1) div 2,
                             map slow_median (chop 5 xs))"
  have less: "((length xs + 4) div 5 - 1) div 2 < nat ⌈length xs / 5⌉"
    using A(1) by linarith

  text ‹For termination, it suffices to prove that @{term x} is in the list.›
  have "x = select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
    using less unfolding x_def by (intro mom_select_correct_aux A) (auto simp: length_chop)
  also have "… ∈ set (map slow_median (chop 5 xs))"
    using less by (intro select_in_set) (simp_all add: length_chop)
  also have "… ⊆ set xs"
    unfolding set_map
  proof safe
    fix ys assume ys: "ys ∈ set (chop 5 xs)"
    hence "median ys ∈ set ys"
      by auto
    also have "set ys ⊆ ⋃(set ` set (chop 5 xs))"
      using ys by blast
    also have "… = set xs"
      by (rule UN_sets_chop) simp_all
    finally show "slow_median ys ∈ set xs"
      by (simp add: slow_median_correct)
  qed
  finally have "x ∈ set xs" .
  thus "((k, ls), k, xs) ∈ measure (length ∘ snd)"
   and "((k - length ls - length es, gs), k, xs) ∈ measure (length ∘ snd)"
    using A(1,2) by (auto simp: partition3_def intro!: length_filter_less[of x])
qed

termination mom_select by (rule mom_select_termination)

lemmas [simp del] = mom_select.simps

text ‹Total correctness: partial correctness combined with termination on all inputs.›
lemma mom_select_correct: "k < length xs ⟹ mom_select k xs = select k xs"
  using mom_select_correct_aux and mom_select_termination by blast



subsection ‹Running time analysis›

text ‹
  Running time charged for ‹partition3 x xs›: one unit per element of ‹xs›, plus one for the
  final ‹[]›.
›
fun T_partition3 :: "'a ⇒ 'a list ⇒ nat" where
  "T_partition3 x [] = 1"
| "T_partition3 x (y # ys) = T_partition3 x ys + 1"

lemma T_partition3_eq: "T_partition3 x xs = length xs + 1"
  by (induction x xs rule: T_partition3.induct) auto


text ‹The time function ‹T_slow_select› is generated automatically by ‹time_definition›.›
time_definition slow_select

lemmas T_slow_select_def [simp del] = T_slow_select.simps


text ‹Running time of ‹slow_median›: computing the length plus running ‹slow_select›.›
definition T_slow_median :: "'a :: linorder list ⇒ nat" where
  "T_slow_median xs = T_length xs + T_slow_select ((length xs - 1) div 2) xs"

text ‹Selection via insertion sort takes at most quadratic time.›
lemma T_slow_select_le:
  assumes "k < length xs"
  shows   "T_slow_select k xs ≤ length xs ^ 2 + 3 * length xs + 1"
proof -
  have "T_slow_select k xs = T_insort xs + T_nth (Sorting.insort xs) k"
    unfolding T_slow_select_def ..
  also have "T_insort xs ≤ (length xs + 1) ^ 2"
    by (rule T_insort_length)
  also have "T_nth (Sorting.insort xs) k = k + 1"
    using assms by (subst T_nth_eq) (auto simp: length_insort)
  also have "k + 1 ≤ length xs"
    using assms by linarith
  also have "(length xs + 1) ^ 2 + length xs = length xs ^ 2 + 3 * length xs + 1"
    by (simp add: algebra_simps power2_eq_square)
  finally show ?thesis by - simp_all
qed

text ‹The naive median of a non-empty list likewise takes at most quadratic time.›
lemma T_slow_median_le:
  assumes "xs ≠ []"
  shows   "T_slow_median xs ≤ length xs ^ 2 + 4 * length xs + 2"
proof -
  have "T_slow_median xs = length xs + T_slow_select ((length xs - 1) div 2) xs + 1"
    by (simp add: T_slow_median_def T_length_eq)
  also from assms have "length xs > 0"
    by simp
  hence "(length xs - 1) div 2 < length xs"
    by linarith
  hence "T_slow_select ((length xs - 1) div 2) xs ≤ length xs ^ 2 + 3 * length xs + 1"
    by (intro T_slow_select_le) auto
  also have "length xs + … + 1 = length xs ^ 2 + 4 * length xs + 2"
    by (simp add: algebra_simps)
  finally show ?thesis by - simp_all
qed


text ‹Time function for @{const chop}, generated by ‹time_fun›, and its basic equations.›
time_fun chop

lemmas [simp del] = T_chop.simps

lemma T_chop_Nil [simp]: "T_chop d [] = 1"
  by (cases d) (auto simp: T_chop.simps)

lemma T_chop_0 [simp]: "T_chop 0 xs = 1"
  by (auto simp: T_chop.simps)

lemma T_chop_reduce:
  "n > 0 ⟹ xs ≠ [] ⟹ T_chop n xs = T_take n xs + T_drop n xs + T_chop n (drop n xs) + 1"
  by (cases n; cases xs) (auto simp: T_chop.simps)

text ‹Chopping takes linear time.›
lemma T_chop_le: "T_chop d xs ≤ 5 * length xs + 1"
  by (induction d xs rule: T_chop.induct) (auto simp: T_chop_reduce T_take_eq T_drop_eq)


text ‹
  The option ‹domintros› here allows us to explicitly reason about where the function does and
  does not terminate. With this, we can skip the termination proof this time because we can
  reuse the one for @{const mom_select}.
›
text ‹
  Running time of ‹mom_select k xs›, following the structure of its recursion. The pivot ‹x› is
  obtained with @{const mom_select} itself; its cost is the summand ‹T_mom_select idx ms›.
›
function (domintros) T_mom_select :: "nat ⇒ 'a :: linorder list ⇒ nat" where
  "T_mom_select k xs = T_length xs + (
     if length xs ≤ 20 then
       T_slow_select k xs
     else
       let xss = chop 5 xs;
           ms = map slow_median xss;
           idx = (((length xs + 4) div 5 - 1) div 2);
           x = mom_select idx ms;
           (ls, es, gs) = partition3 x xs;
           nl = length ls;
           ne = length es
       in
         (if k < nl then T_mom_select k ls
          else T_length es + (if k < nl + ne then 0 else T_mom_select (k - nl - ne) gs)) +
         T_mom_select idx ms + T_chop 5 xs + T_map T_slow_median xss +
         T_partition3 x xs + T_length ls + 1
      )"
  by auto

termination T_mom_select
proof (rule allI, safe)
  (* Every argument in the domain of mom_select (i.e. every argument) is in the domain of
     T_mom_select, since both functions make the same recursive calls. *)
  fix k :: nat and xs :: "'a :: linorder list"
  have "mom_select_dom (k, xs)"
    using mom_select_termination by blast
  thus "T_mom_select_dom (k, xs)"
    by (induction k xs rule: mom_select.pinduct)
       (rule T_mom_select.domintros, simp_all)
qed

lemmas [simp del] = T_mom_select.simps


text ‹
  A recurrence on the input length that is used as an upper bound for
  \<^const>‹T_mom_select› (cf. lemma T_mom_select_le_aux). For n up to 20 it is the constant 482,
  which is the base-case cost bound n^2 + 4n + 2 evaluated at n = 20. Above that, it consists
  of one call on roughly a fifth of the input and one on roughly seven tenths of it (plus 3),
  together with a linear overhead of 19n + 54.
›
function T'_mom_select :: "nat ⇒ nat" where
  "T'_mom_select n =
     (if n ≤ 20 then
        482
      else
        T'_mom_select (nat ⌈0.2*n⌉) + T'_mom_select (nat ⌈0.7*n+3⌉) + 19 * n + 54)"
  by force+
termination by (relation "measure id"; simp; linarith)

lemmas [simp del] = T'_mom_select.simps


text ‹The recurrence never yields less than its base-case value 482.›
lemma T'_mom_select_ge: "T'_mom_select n ≥ 482"
  by (induction n rule: T'_mom_select.induct; subst T'_mom_select.simps) auto

text ‹
  The recurrence \<^const>‹T'_mom_select› is monotone in its argument. This is proved by strong
  induction, comparing the base case against the lower bound 482 and the recursive case
  summand by summand.
›
lemma T'_mom_select_mono:
  "m ≤ n ⟹ T'_mom_select m ≤ T'_mom_select n"
proof (induction n arbitrary: m rule: less_induct)
  case (less n m)
  show ?case
  proof (cases "m ≤ 20")
    case True
    hence "T'_mom_select m = 482"
      by (subst T'_mom_select.simps) auto
    also have "… ≤ T'_mom_select n"
      by (rule T'_mom_select_ge)
    finally show ?thesis .
  next
    case False
    hence "T'_mom_select m =
             T'_mom_select (nat ⌈0.2*m⌉) + T'_mom_select (nat ⌈0.7*m + 3⌉) + 19 * m + 54"
      by (subst T'_mom_select.simps) auto
    also have "… ≤ T'_mom_select (nat ⌈0.2*n⌉) + T'_mom_select (nat ⌈0.7*n + 3⌉) + 19 * n + 54"
      using ‹m ≤ n› and False by (intro add_mono less.IH; linarith)
    also have "… = T'_mom_select n"
      using ‹m ≤ n› and False by (subst T'_mom_select.simps) auto
    finally show ?thesis .
  qed
qed

text ‹
  The main cost bound: for any valid index k, the running time \<^const>‹T_mom_select› on a list
  is bounded by the recurrence \<^const>‹T'_mom_select› evaluated at the length of that list.
  The proof follows the recursion structure of \<^const>‹T_mom_select›.
›
lemma T_mom_select_le_aux:
  assumes "k < length xs"
  shows   "T_mom_select k xs ≤ T'_mom_select (length xs)"
  using assms
proof (induction k xs rule: T_mom_select.induct)
  case (1 k xs)
  define n where [simp]: "n = length xs"
  define x where
    "x = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
  define ls es gs where "ls = filter (λy. y < x) xs" and "es = filter (λy. y = x) xs"
                    and "gs = filter (λy. y > x) xs"
  define nl ne where "nl = length ls" and "ne = length es"
  note defs = nl_def ne_def x_def ls_def es_def gs_def
  have tw: "(ls, es, gs) = partition3 x xs"
    unfolding partition3_def defs One_nat_def ..
  note IH = "1.IH"(1,2,3)[OF _ refl refl refl x_def tw refl refl refl refl]

  show ?case
  proof (cases "length xs ≤ 20")
    case True ― ‹base case›
    hence "T_mom_select k xs ≤ (length xs) ^ 2 + 4 * length xs + 2"
      using T_slow_select_le[of k xs] ‹k < length xs›
      by (subst T_mom_select.simps) (auto simp: T_length_eq)
    also have "… ≤ 20 ^ 2 + 4 * 20 + 2"
      using True by (intro add_mono power_mono) auto
    also have "… = 482"
      by simp
    also have "… = T'_mom_select (length xs)"
      using True by (simp add: T'_mom_select.simps)
    finally show ?thesis by simp
  next
    case False ― ‹recursive case›
    have "((n + 4) div 5 - 1) div 2 < nat ⌈n / 5⌉"
      using False unfolding n_def by linarith
    hence "x = select (((n + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
      unfolding x_def n_def by (intro mom_select_correct) (auto simp: length_chop)
    also have "((n + 4) div 5 - 1) div 2 = (nat ⌈n / 5⌉ - 1) div 2"
      by linarith
    also have "select … (map slow_median (chop 5 xs)) = median (map slow_median (chop 5 xs))"
      by (auto simp: median_def length_chop)
    finally have x_eq: "x = median (map slow_median (chop 5 xs))" .

    text ‹The cost of computing the medians of all the subgroups:›
    define T_ms where "T_ms = T_map T_slow_median (chop 5 xs)"
    have "T_ms ≤ 10 * n + 48"
    proof -
      have "T_ms = (∑ys←chop 5 xs. T_slow_median ys) + length (chop 5 xs) + 1"
        by (simp add: T_ms_def T_map_eq)
      also have "(∑ys←chop 5 xs. T_slow_median ys) ≤ (∑ys←chop 5 xs. 47)"
      proof (intro sum_list_mono)
        fix ys assume "ys ∈ set (chop 5 xs)"
        hence "length ys ≤ 5" "ys ≠ []"
          using length_chop_part_le[of ys 5 xs] by auto
        from ‹ys ≠ []› have "T_slow_median ys ≤ (length ys) ^ 2 + 4 * length ys + 2"
          by (rule T_slow_median_le)
        also have "… ≤ 5 ^ 2 + 4 * 5 + 2"
          using ‹length ys ≤ 5› by (intro add_mono power_mono) auto
        finally show "T_slow_median ys ≤ 47" by simp
      qed
      also have "(∑ys←chop 5 xs. 47) + length (chop 5 xs) + 1 =
                   48 * nat ⌈real n / 5⌉ + 1"
        by (simp add: map_replicate_const length_chop)
      also have "… ≤ 10 * n + 48"
        by linarith
      finally show "T_ms ≤ 10 * n + 48" by simp
    qed

    text ‹The cost of the first recursive call (to compute the median of medians):›
    define T_rec1 where
      "T_rec1 = T_mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))"
    from False have "((length xs + 4) div 5 - Suc 0) div 2 < nat ⌈real (length xs) / 5⌉"
      by linarith
    hence "T_rec1 ≤ T'_mom_select (length (map slow_median (chop 5 xs)))"
      using False unfolding T_rec1_def by (intro IH(3)) (auto simp: length_chop)
    hence "T_rec1 ≤ T'_mom_select (nat ⌈0.2 * n⌉)"
      by (simp add: length_chop)

    text ‹The cost of the second recursive call (to compute the final result):›
    define T_rec2 where "T_rec2 = (if k < nl then T_mom_select k ls
                                 else if k < nl + ne then 0
                                 else T_mom_select (k - nl - ne) gs)"
    consider "k < nl" | "k ∈ {nl..<nl+ne}" | "k ≥ nl+ne"
      by force
    hence "T_rec2 ≤ T'_mom_select (nat ⌈0.7 * n + 3⌉)"
    proof cases
      assume "k < nl"
      hence "T_rec2 = T_mom_select k ls"
        by (simp add: T_rec2_def)
      also have "… ≤ T'_mom_select (length ls)"
        by (rule IH(1)) (use ‹k < nl› False in ‹auto simp: defs›)
      also have "length ls ≤ nat ⌈0.7 * n + 3⌉"
        unfolding ls_def using size_less_than_median_of_medians[of xs]
        by (auto simp: length_filter_conv_size_filter_mset slow_median_correct[abs_def] x_eq)
      hence "T'_mom_select (length ls) ≤ T'_mom_select (nat ⌈0.7 * n + 3⌉)"
        by (rule T'_mom_select_mono)
      finally show ?thesis .
    next
      assume "k ∈ {nl..<nl + ne}"
      hence "T_rec2 = 0"
        by (simp add: T_rec2_def)
      thus ?thesis
        using T'_mom_select_ge[of "nat ⌈0.7 * n + 3⌉"] by simp
    next
      assume "k ≥ nl + ne"
      hence "T_rec2 = T_mom_select (k - nl - ne) gs"
        by (simp add: T_rec2_def)
      also have "… ≤ T'_mom_select (length gs)"
        unfolding nl_def ne_def 
      proof (rule IH(2))
        show "¬ length xs ≤ 20"
          using False by auto
        show "¬ k < length ls" "¬k < length ls + length es"
          using ‹k ≥ nl + ne› by (auto simp: nl_def ne_def)
        have "length xs = nl + ne + length gs"
          unfolding defs by (rule length_partition3) (simp_all add: partition3_def)
        thus "k - length ls - length es < length gs"
          using ‹k ≥ nl + ne› ‹k < length xs› by (auto simp: nl_def ne_def)
      qed
      also have "length gs ≤ nat ⌈0.7 * n + 3⌉"
        unfolding gs_def using size_greater_than_median_of_medians[of xs]
        by (auto simp: length_filter_conv_size_filter_mset slow_median_correct[abs_def] x_eq)
      hence "T'_mom_select (length gs) ≤ T'_mom_select (nat ⌈0.7 * n + 3⌉)"
        by (rule T'_mom_select_mono)
      finally show ?thesis .
    qed

    text ‹Now for the final inequality chain:›
    have "T_mom_select k xs ≤ T_rec2 + T_rec1 + T_ms + 2 * n + nl + ne + T_chop 5 xs + 5" using False
      by (subst T_mom_select.simps, unfold Let_def tw [symmetric] defs [symmetric])
         (simp_all add: nl_def ne_def T_rec1_def T_rec2_def T_partition3_eq
                        T_length_eq T_ms_def)
    also have "nl ≤ n" by (simp add: nl_def ls_def)
    also have "ne ≤ n" by (simp add: ne_def es_def)
    also note ‹T_ms ≤ 10 * n + 48›
    also have "T_chop 5 xs ≤ 5 * n + 1"
      using T_chop_le[of 5 xs] by simp 
    also note ‹T_rec1 ≤ T'_mom_select (nat ⌈0.2*n⌉)›
    also note ‹T_rec2 ≤ T'_mom_select (nat ⌈0.7*n + 3⌉)›
    finally have "T_mom_select k xs ≤
                    T'_mom_select (nat ⌈0.7*n + 3⌉) + T'_mom_select (nat ⌈0.2*n⌉) + 19 * n + 54"
      by simp
    also have "… = T'_mom_select n"
      using False by (subst T'_mom_select.simps) auto
    finally show ?thesis by simp
  qed
qed

subsection ‹Akra--Bazzi Light›

text ‹
  Auxiliary fact: if 0 < a < 1 and n exceeds a threshold n0 that is at least
  (max 0 b + 1) / (1 - a), then the rounded-up argument nat ⌈a*n+b⌉ is strictly smaller than n.
  This guarantees that the recursive calls in the recurrences below are well-founded.
›
lemma akra_bazzi_light_aux1:
  fixes a b :: real and n n0 :: nat
  assumes ab: "a > 0" "a < 1" "n > n0"
  assumes "n0 ≥ (max 0 b + 1) / (1 - a)"
  shows "nat ⌈a*n+b⌉ < n"
proof -
  have "a * real n + max 0 b ≥ 0"
    using ab by simp
  hence "real (nat ⌈a*n+b⌉) ≤ a * n + max 0 b + 1"
    by linarith
  also {
    have "n0 ≥ (max 0 b + 1) / (1 - a)"
      by fact
    also have "… < real n"
      using assms by simp
    finally have "a * real n + max 0 b + 1 < real n"
      using ab by (simp add: field_simps)
  }
  finally show "nat ⌈a*n+b⌉ < n"
    using ‹n > n0› by linarith
qed

text ‹
  If f satisfies the two-call recurrence for all arguments above n0, and the constants C3 and
  C4 satisfy the stated inequalities (C4 bounds f up to n0, and C3 is large enough), then f is
  bounded by the linear function C3 * n + C4. The proof is by strong induction on n.
›
lemma akra_bazzi_light_aux2:
  fixes f :: "nat ⇒ real"
  fixes n0 :: nat and a b c d :: real and C1 C2 C3 C4 :: real
  assumes bounds: "a > 0" "c > 0" "a + c < 1" "C1 ≥ 0"
  assumes rec: "∀n>n0. f n = f (nat ⌈a*n+b⌉) + f (nat ⌈c*n+d⌉) + C1 * n + C2"
  assumes ineqs: "n0 > (max 0 b + max 0 d + 2) / (1 - a - c)"
                 "C3 ≥ C1 / (1 - a - c)"
                 "C3 ≥ (C1 * n0 + C2 + C4) / ((1 - a - c) * n0 - max 0 b - max 0 d - 2)"
                 "∀n≤n0. f n ≤ C4"
  shows   "f n ≤ C3 * n + C4"
proof (induction n rule: less_induct)
  case (less n)
  have "0 ≤ C1 / (1 - a - c)"
    using bounds by auto
  also have "… ≤ C3"
    by fact
  finally have "C3 ≥ 0" .

  show ?case
  proof (cases "n > n0")
    case False ― ‹below the threshold, C4 bounds f directly›
    hence "f n ≤ C4"
      using ineqs(4) by auto
    also have "… ≤ C3 * real n + C4"
      using bounds ‹C3 ≥ 0› by auto
    finally show ?thesis .
  next
    case True ― ‹above the threshold, unfold the recurrence and use the induction hypothesis›
    have nonneg: "a * n ≥ 0" "c * n ≥ 0"
      using bounds by simp_all

    have "(max 0 b + 1) / (1 - a) ≤ (max 0 b + max 0 d + 2) / (1 - a - c)"
      using bounds by (intro frac_le) auto
    hence "n0 ≥ (max 0 b + 1) / (1 - a)"
      using ineqs(1) by linarith
    hence rec_less1: "nat ⌈a*n+b⌉ < n"
      using bounds ‹n > n0› by (intro akra_bazzi_light_aux1[of _ n0]) auto

    have "(max 0 d + 1) / (1 - c) ≤ (max 0 b + max 0 d + 2) / (1 - a - c)"
      using bounds by (intro frac_le) auto
    hence "n0 ≥ (max 0 d + 1) / (1 - c)"
      using ineqs(1) by linarith
    hence rec_less2: "nat ⌈c*n+d⌉ < n"
      using bounds ‹n > n0› by (intro akra_bazzi_light_aux1[of _ n0]) auto

    have "f n = f (nat ⌈a*n+b⌉) + f (nat ⌈c*n+d⌉) + C1 * n + C2"
      using ‹n > n0› by (subst rec) auto
    also have "… ≤ (C3 * nat ⌈a*n+b⌉ + C4) + (C3 * nat ⌈c*n+d⌉ + C4) + C1 * n + C2"
      using rec_less1 rec_less2 by (intro add_mono less.IH) auto
    also have "… ≤ (C3 * (a*n+max 0 b+1) + C4) + (C3 * (c*n+max 0 d+1) + C4) + C1 * n + C2"
      using bounds ‹C3 ≥ 0› nonneg by (intro add_mono mult_left_mono order.refl; linarith)
    also have "… = C3 * n  +  ((C3 * (max 0 b + max 0 d + 2) + 2 * C4 + C2) -
                                 (C3 * (1 - a - c) - C1) * n)"
      by (simp add: algebra_simps)
    also have "… ≤ C3 * n  +  ((C3 * (max 0 b + max 0 d + 2) + 2 * C4 + C2) -
                                 (C3 * (1 - a - c) - C1) * n0)"
      using ‹n > n0› ineqs(2) bounds
      by (intro add_mono diff_mono order.refl mult_left_mono) (auto simp: field_simps)
    also have "(C3 * (max 0 b + max 0 d + 2) + 2 * C4 + C2) - (C3 * (1 - a - c) - C1) * n0 ≤ C4"
      using ineqs bounds by (simp add: field_simps)
    finally show "f n ≤ C3 * real n + C4"
      by (simp add: mult_right_mono)
  qed
qed

text ‹
  Existence of a linear bound for the two-call recurrence. The proof chooses a threshold n0'
  that is at least n0 and large enough for akra_bazzi_light_aux2. It takes C4 to be the maximum
  of f on {..n0'}, and C3 to be the larger of the two lower bounds that akra_bazzi_light_aux2
  requires.
›
lemma akra_bazzi_light:
  fixes f :: "nat ⇒ real"
  fixes n0 :: nat and a b c d C1 C2 :: real
  assumes bounds: "a > 0" "c > 0" "a + c < 1" "C1 ≥ 0"
  assumes rec: "∀n>n0. f n = f (nat ⌈a*n+b⌉) + f (nat ⌈c*n+d⌉) + C1 * n + C2"
  shows "∃C3 C4. ∀n. f n ≤ C3 * real n + C4"
proof -
  define n0' where "n0' = max n0 (nat ⌈(max 0 b + max 0 d + 2) / (1 - a - c)⌉ + 1)"
  define C4 where "C4 = Max (f ` {..n0'})"
  define C3 where "C3 = max (C1 / (1 - a - c))
                         ((C1 * n0' + C2 + C4) / ((1 - a - c) * n0' - max 0 b - max 0 d - 2))"

  have "f n ≤ C3 * n + C4" for n
  proof (rule akra_bazzi_light_aux2[OF bounds _])
    show "∀n>n0'. f n = f (nat ⌈a*n+b⌉) + f (nat ⌈c*n+d⌉) + C1 * n + C2"
      using rec by (auto simp: n0'_def)
  next
    show "C3 ≥ C1 / (1 - a - c)" 
     and "C3 ≥ (C1 * n0' + C2 + C4) / ((1 - a - c) * n0' - max 0 b - max 0 d - 2)"
      by (simp_all add: C3_def)
  next
    have "(max 0 b + max 0 d + 2) / (1 - a - c) < nat ⌈(max 0 b + max 0 d + 2) / (1 - a - c)⌉ + 1"
      by linarith
    also have "… ≤ n0'"
      by (simp add: n0'_def)
    finally show "(max 0 b + max 0 d + 2) / (1 - a - c) < real n0'" .
  next
    show "∀n≤n0'. f n ≤ C4"
      by (auto simp: C4_def)
  qed
  thus ?thesis by blast
qed

text ‹
  Natural-number version of akra_bazzi_light. It first obtains real constants C3 and C4, then
  rounds them up to natural numbers.
›
lemma akra_bazzi_light_nat:
  fixes f :: "nat ⇒ nat"
  fixes n0 :: nat and a b c d :: real and C1 C2 :: nat
  assumes bounds: "a > 0" "c > 0" "a + c < 1" "C1 ≥ 0"
  assumes rec: "∀n>n0. f n = f (nat ⌈a*n+b⌉) + f (nat ⌈c*n+d⌉) + C1 * n + C2"
  shows "∃C3 C4. ∀n. f n ≤ C3 * n + C4"
proof -
  have "∃C3 C4. ∀n. real (f n) ≤ C3 * real n + C4"
    using assms by (intro akra_bazzi_light[of a c C1 n0 f b d C2]) auto
  then obtain C3 C4 where le: "∀n. real (f n) ≤ C3 * real n + C4"
    by blast
  have "f n ≤ nat ⌈C3⌉ * n + nat ⌈C4⌉" for n
  proof -
    have "real (f n) ≤ C3 * real n + C4"
      using le by blast
    also have "… ≤ real (nat ⌈C3⌉) * real n + real (nat ⌈C4⌉)"
      by (intro add_mono mult_right_mono; linarith)
    also have "… = real (nat ⌈C3⌉ * n + nat ⌈C4⌉)"
      by simp
    finally show ?thesis by linarith
  qed
  thus ?thesis by blast
qed

text ‹
  The recurrence \<^const>‹T'_mom_select› is bounded by a linear function. Together with
  T_mom_select_le_aux, this shows that \<^const>‹mom_select› runs in linear time on valid indices.
›
lemma T'_mom_select_le': "∃C1 C2. ∀n. T'_mom_select n ≤ C1 * n + C2"
proof (rule akra_bazzi_light_nat)
  show "∀n>20. T'_mom_select n = T'_mom_select (nat ⌈0.2 * n + 0⌉) +
                 T'_mom_select (nat ⌈0.7 * n + 3⌉) + 19 * n + 54"
    using T'_mom_select.simps by auto
qed auto

end