(* Theory HOL-Library.DAList_Multiset *)

(*  Title:      HOL/Library/DAList_Multiset.thy
    Author:     Lukas Bulwahn, TU Muenchen
*)

section ‹Multisets partially implemented by association lists›

theory DAList_Multiset
imports Multiset DAList
begin

text ‹Delete pre-existing code equations›

(* Remove the default code equations of these multiset operations; most of them
   are replaced below by [code] lemmas phrased over the Bag constructor. *)
declare [[code drop: "{#}" Multiset.is_empty add_mset
  "plus :: 'a multiset ⇒ _" "minus :: 'a multiset ⇒ _"
  inter_mset union_mset image_mset filter_mset count
  "size :: _ multiset ⇒ nat" sum_mset prod_mset
  set_mset sorted_list_of_multiset subset_mset subseteq_mset
  equal_multiset_inst.equal_multiset]]
    

text ‹Raw operations on lists›

(* join_raw f xs ys merges ys into xs by a right fold over ys: each entry (k, v)
   of ys is passed to map_default k v (λv'. f k (v', v)).  The resulting lookup
   behaviour is characterised by map_of_join_raw below. *)
definition join_raw ::
    "('key ⇒ 'val × 'val ⇒ 'val) ⇒
      ('key × 'val) list ⇒ ('key × 'val) list ⇒ ('key × 'val) list"
  where "join_raw f xs ys = foldr (λ(k, v). map_default k v (λv'. f k (v', v))) ys xs"

(* Joining with an empty list leaves xs unchanged. *)
lemma join_raw_Nil [simp]: "join_raw f xs [] = xs"
  unfolding join_raw_def by simp

(* One-step unfolding of join_raw on a non-empty right argument. *)
lemma join_raw_Cons [simp]:
  "join_raw f xs ((k, v) # ys) = map_default k v (λv'. f k (v', v)) (join_raw f xs ys)"
  unfolding join_raw_def by simp

(* Lookup in the joined list: a key bound only in xs or only in ys keeps its value,
   a key bound in both is combined with f.  Requires the keys of ys to be distinct. *)
lemma map_of_join_raw:
  assumes "distinct (map fst ys)"
  shows "map_of (join_raw f xs ys) x =
    (case map_of xs x of
      None ⇒ map_of ys x
    | Some v ⇒ (case map_of ys x of None ⇒ Some v | Some v' ⇒ Some (f x (v, v'))))"
  using assms
  apply (induct ys)
  apply (auto simp add: map_of_map_default split: option.split)
  apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
  apply (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
  done

(* join_raw preserves distinct keys: if xs has distinct keys, so does
   join_raw f xs ys, for any ys. *)
lemma distinct_join_raw:
  assumes "distinct (map fst xs)"
  shows "distinct (map fst (join_raw f xs ys))"
  using assms
proof (induct ys)
  case Nil
  then show ?case by simp
next
  case (Cons y ys)
  (* split the pair y so that the [simp] rule join_raw_Cons, stated for (k, v) # ys, applies *)
  then show ?case by (cases y) (simp add: distinct_map_default)
qed

(* subtract_entries_raw xs ys processes each entry (k, v) of ys with
   AList.map_entry k (λv'. v' - v), i.e. subtracts v from the value stored under k
   in xs.  Keys absent from xs stay absent (see map_of_subtract_entries_raw).
   The explicit signature mirrors the one given for join_raw. *)
definition subtract_entries_raw ::
    "('key × ('val::minus)) list ⇒ ('key × 'val) list ⇒ ('key × 'val) list"
  where "subtract_entries_raw xs ys = foldr (λ(k, v). AList.map_entry k (λv'. v' - v)) ys xs"

(* Lookup after subtraction: keys absent from xs stay absent; for keys of xs the
   value bound in ys (if any) is subtracted.  Requires the keys of ys to be distinct. *)
lemma map_of_subtract_entries_raw:
  assumes "distinct (map fst ys)"
  shows "map_of (subtract_entries_raw xs ys) x =
    (case map_of xs x of
      None ⇒ None
    | Some v ⇒ (case map_of ys x of None ⇒ Some v | Some v' ⇒ Some (v - v')))"
  using assms
  unfolding subtract_entries_raw_def
  apply (induct ys)
  apply auto
  apply (simp split: option.split)
  apply (simp add: map_of_map_entry)
  apply (auto split: option.split)
  apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
  apply (metis map_of_eq_None_iff option.simps(4) option.simps(5))
  done

(* subtract_entries_raw keeps keys distinct: whenever xs has distinct keys,
   so does the result. *)
lemma distinct_subtract_entries_raw:
  assumes "distinct (map fst xs)"
  shows "distinct (map fst (subtract_entries_raw xs ys))"
  unfolding subtract_entries_raw_def
  using assms
  by (induct ys) (auto simp: distinct_map_entry)


text ‹Operations on alists with distinct keys›

(* join lifts join_raw to association lists with distinct keys; the required
   invariant preservation is distinct_join_raw. *)
lift_definition join :: "('a ⇒ 'b × 'b ⇒ 'b) ⇒ ('a, 'b) alist ⇒ ('a, 'b) alist ⇒ ('a, 'b) alist"
  is join_raw
  by (simp add: distinct_join_raw)

(* subtract_entries lifts subtract_entries_raw to association lists with distinct
   keys; the required invariant preservation is distinct_subtract_entries_raw. *)
lift_definition subtract_entries :: "('a, ('b :: minus)) alist ⇒ ('a, 'b) alist ⇒ ('a, 'b) alist"
  is subtract_entries_raw
  by (simp add: distinct_subtract_entries_raw)


text ‹Implementing multisets by means of association lists›

(* count_of xs x is the multiplicity stored for x in xs, or 0 if x has no entry. *)
definition count_of :: "('a × nat) list ⇒ 'a ⇒ nat"
  where "count_of xs x = (case map_of xs x of None ⇒ 0 | Some n ⇒ n)"

(* count_of xs has finite support; this is the fact used below to apply
   Abs_multiset_inverse to count_of. *)
lemma count_of_multiset: "finite {x. 0 < count_of xs x}"
proof -
  let ?A = "{x::'a. 0 < (case map_of xs x of None ⇒ 0::nat | Some n ⇒ n)}"
  (* every key with a positive count is bound in xs *)
  have "?A ⊆ dom (map_of xs)"
  proof
    fix x
    assume "x ∈ ?A"
    then have "0 < (case map_of xs x of None ⇒ 0::nat | Some n ⇒ n)"
      by simp
    then have "map_of xs x ≠ None"
      by (cases "map_of xs x") auto
    then show "x ∈ dom (map_of xs)"
      by auto
  qed
  (* finite_dom_map_of gives finiteness of dom (map_of xs); ?A is a subset *)
  with finite_dom_map_of [of xs] have "finite ?A"
    by (auto intro: finite_subset)
  then show ?thesis
    by (simp add: count_of_def fun_eq_iff)
qed

(* Computation rules for count_of on the two list constructors; the first entry
   for a key determines its count. *)
lemma count_simps [simp]:
  "count_of [] = (λ_. 0)"
  "count_of ((x, n) # xs) = (λy. if x = y then n else count_of xs y)"
  by (simp_all add: count_of_def fun_eq_iff)

(* Keys without an entry have count 0. *)
lemma count_of_empty: "x ∉ fst ` set xs ⟹ count_of xs x = 0"
  by (induct xs) (simp_all add: count_of_def)

(* Filtering entries by a predicate on keys zeroes the counts of rejected keys. *)
lemma count_of_filter: "count_of (List.filter (P ∘ fst) xs) x = (if P x then count_of xs x else 0)"
  by (induct xs) auto

(* Adding b to the entry of x via map_default increases the count of x by b
   and leaves all other counts unchanged. *)
lemma count_of_map_default [simp]:
  "count_of (map_default x b (λx. x + b) xs) y =
    (if x = y then count_of xs x + b else count_of xs y)"
  unfolding count_of_def by (simp add: map_of_map_default split: option.split)

(* Joining with addition adds the counts pointwise (keys of ys must be distinct). *)
lemma count_of_join_raw:
  "distinct (map fst ys) ⟹
    count_of xs x + count_of ys x = count_of (join_raw (λx (x, y). x + y) xs ys) x"
  unfolding count_of_def by (simp add: map_of_join_raw split: option.split)

(* subtract_entries_raw subtracts counts pointwise, with truncated subtraction on
   nat (keys of ys must be distinct). *)
lemma count_of_subtract_entries_raw:
  "distinct (map fst ys) ⟹
    count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
  unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)


text ‹Code equations for multiset operations›

(* Bag interprets a distinct-key association list of multiplicities as a multiset.
   Entries with multiplicity 0 are permitted (cf. is_empty_Bag_impl), so the
   representation is not unique.  Bag is declared as the code datatype constructor,
   so the [code] lemmas below pattern-match on it. *)
definition Bag :: "('a, nat) alist ⇒ 'a multiset"
  where "Bag xs = Abs_multiset (count_of (DAList.impl_of xs))"

code_datatype Bag

(* Multiplicities of a Bag are read off directly from the association list. *)
lemma count_Bag [simp, code]: "count (Bag xs) = count_of (DAList.impl_of xs)"
  unfolding Bag_def by (simp add: count_of_multiset)

(* The empty multiset is represented by the empty association list. *)
lemma Mempty_Bag [code]: "{#} = Bag (DAList.empty)"
  by (simp add: multiset_eq_iff alist.Alist_inverse DAList.empty_def)

(* An association list represents the empty multiset iff every stored
   multiplicity is 0 (see is_empty_Bag). *)
lift_definition is_empty_Bag_impl :: "('a, nat) alist ⇒ bool" is
  "λxs. list_all (λx. snd x = 0) xs" .

(* Code equation for emptiness: only the keys stored in the list need checking. *)
lemma is_empty_Bag [code]: "Multiset.is_empty (Bag xs) ⟷ is_empty_Bag_impl xs"
proof -
  have "Multiset.is_empty (Bag xs) ⟷ (∀x. count (Bag xs) x = 0)"
    unfolding Multiset.is_empty_def multiset_eq_iff by simp
  (* keys without an entry have count 0, so it suffices to quantify over stored keys *)
  also have "… ⟷ (∀x∈fst ` set (alist.impl_of xs). count (Bag xs) x = 0)"
  proof (intro iffI allI ballI)
    fix x assume A: "∀x∈fst ` set (alist.impl_of xs). count (Bag xs) x = 0"
    thus "count (Bag xs) x = 0"
    proof (cases "x ∈ fst ` set (alist.impl_of xs)")
      case False
      thus ?thesis by (force simp: count_of_def split: option.splits)
    qed (insert A, auto)
  qed simp_all
  also have "… ⟷ list_all (λx. snd x = 0) (alist.impl_of xs)"
    by (auto simp: count_of_def list_all_def)
  finally show ?thesis by (simp add: is_empty_Bag_impl.rep_eq)
qed

(* Multiset sum: join the association lists, adding the multiplicities of shared keys. *)
lemma union_Bag [code]: "Bag xs + Bag ys = Bag (join (λx (n1, n2). n1 + n2) xs ys)"
  by (rule multiset_eqI)
    (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)

(* add_mset x is implemented by joining with the alist DAList.update x 1 DAList.empty,
   reducing it to union_Bag. *)
lemma add_mset_Bag [code]: "add_mset x (Bag xs) =
    Bag (join (λx (n1, n2). n1 + n2) (DAList.update x 1 DAList.empty) xs)"
  unfolding add_mset_add_single[of x "Bag xs"] union_Bag[symmetric]
  by (simp add: multiset_eq_iff update.rep_eq empty.rep_eq)

(* Multiset difference: subtract multiplicities entrywise with subtract_entries
   (truncated subtraction on nat). *)
lemma minus_Bag [code]: "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
  by (rule multiset_eqI)
    (simp add: count_of_subtract_entries_raw alist.Alist_inverse
      distinct_subtract_entries_raw subtract_entries_def)

(* Filtering a multiset filters the association list by the predicate on keys. *)
lemma filter_Bag [code]: "filter_mset P (Bag xs) = Bag (DAList.filter (P ∘ fst) xs)"
  by (rule multiset_eqI) (simp add: count_of_filter DAList.filter.rep_eq)


(* Equality of multisets is decided by two inclusion tests. *)
lemma mset_eq [code]: "HOL.equal (m1::'a::equal multiset) m2 ⟷ m1 ⊆# m2 ∧ m2 ⊆# m1"
  by (metis equal_multiset_def subset_mset.order_eq_iff)

text ‹By default the code for ‹<› is \<^prop>‹xs < ys ⟷ xs ≤ ys ∧ ¬ xs = ys›.
With equality implemented by ‹≤›, this leads to three calls of ‹≤›.
Here is a more efficient version:›
(* Strict inclusion via two ⊆# tests. *)
lemma mset_less[code]: "xs ⊂# (ys :: 'a multiset) ⟷ xs ⊆# ys ∧ ¬ ys ⊆# xs"
  by (rule subset_mset.less_le_not_le)

(* Inclusion Bag xs ⊆# A, reduced to a check over the entries stored in xs. *)
lemma mset_less_eq_Bag0:
  "Bag xs ⊆# A ⟷ (∀(x, n) ∈ set (DAList.impl_of xs). count_of (DAList.impl_of xs) x ≤ count A x)"
    (is "?lhs ⟷ ?rhs")
proof
  assume ?lhs
  then show ?rhs by (auto simp add: subseteq_mset_def)
next
  assume ?rhs
  show ?lhs
  proof (rule mset_subset_eqI)
    fix x
    (* keys without an entry have count_of 0 (count_of_empty), which is trivially bounded *)
    from ‹?rhs› have "count_of (DAList.impl_of xs) x ≤ count A x"
      by (cases "x ∈ fst ` set (DAList.impl_of xs)") (auto simp add: count_of_empty)
    then show "count (Bag xs) x ≤ count A x" by (simp add: subset_mset_def)
  qed
qed

(* Code equation for ⊆#: compare each stored multiplicity n directly with count A x. *)
lemma mset_less_eq_Bag [code]:
  "Bag xs ⊆# (A :: 'a multiset) ⟷ (∀(x, n) ∈ set (DAList.impl_of xs). n ≤ count A x)"
proof -
  {
    fix x n
    assume "(x,n) ∈ set (DAList.impl_of xs)"
    (* keys are distinct, so the stored entry (x, n) determines count_of at x *)
    then have "count_of (DAList.impl_of xs) x = n"
    proof transfer
      fix x n
      fix xs :: "('a × nat) list"
      show "(distinct ∘ map fst) xs ⟹ (x, n) ∈ set xs ⟹ count_of xs x = n"
      proof (induct xs)
        case Nil
        then show ?case by simp
      next
        case (Cons ym ys)
        obtain y m where ym: "ym = (y,m)" by force
        note Cons = Cons[unfolded ym]
        show ?case
        proof (cases "x = y")
          case False
          with Cons show ?thesis
            unfolding ym by auto
        next
          case True
          with Cons(2-3) have "m = n" by force
          with True show ?thesis
            unfolding ym by auto
        qed
      qed
    qed
  }
  then show ?thesis
    unfolding mset_less_eq_Bag0 by auto
qed

(* Use the defining equations of inter_mset and union_mset (whose code equations
   were dropped above) and the equations of mset as code equations. *)
declare
  inter_mset_def [code]
  union_mset_def [code]
  mset.simps [code]


(* fold_impl fn e ms walks the (key, multiplicity) pairs of ms from left to right,
   starting from accumulator e; the pair (a, n) transforms the accumulator by fn a n. *)
fun fold_impl :: "('a ⇒ nat ⇒ 'b ⇒ 'b) ⇒ 'b ⇒ ('a × nat) list ⇒ 'b"
where
  "fold_impl fn e ((a,n) # ms) = (fold_impl fn ((fn a n) e) ms)"
| "fold_impl fn e [] = e"

context
begin

(* DAList_Multiset.fold runs fold_impl on the list underlying an alist.  It is
   qualified (referenced as DAList_Multiset.fold below), presumably to avoid a
   clash with List.fold. *)
qualified definition fold :: "('a ⇒ nat ⇒ 'b ⇒ 'b) ⇒ 'b ⇒ ('a, nat) alist ⇒ 'b"
  where "fold f e al = fold_impl f e (DAList.impl_of al)"

end

context comp_fun_commute
begin

(* fold_mset f over a Bag agrees with DAList_Multiset.fold fn, provided fn a n
   applies f a exactly n times (assumption fn).  The proof inducts over the
   underlying distinct-key list, peeling off one entry at a time. *)
lemma DAList_Multiset_fold:
  assumes fn: "⋀a n x. fn a n x = (f a ^^ n) x"
  shows "fold_mset f e (Bag al) = DAList_Multiset.fold fn e al"
  unfolding DAList_Multiset.fold_def
proof (induct al)
  fix ys
  let ?inv = "{xs :: ('a × nat) list. (distinct ∘ map fst) xs}"
  note cs[simp del] = count_simps
  have count[simp]: "⋀x. count (Abs_multiset (count_of x)) = count_of x"
    by (rule Abs_multiset_inverse) (simp add: count_of_multiset)
  assume ys: "ys ∈ ?inv"
  then show "fold_mset f e (Bag (Alist ys)) = fold_impl fn e (DAList.impl_of (Alist ys))"
    unfolding Bag_def unfolding Alist_inverse[OF ys]
  proof (induct ys arbitrary: e rule: list.induct)
    case Nil
    show ?case
      by (rule trans[OF arg_cong[of _ "{#}" "fold_mset f e", OF multiset_eqI]])
         (auto, simp add: cs)
  next
    case (Cons pair ys e)
    obtain a n where pair: "pair = (a,n)"
      by force
    from fn[of a n] have [simp]: "fn a n = (f a ^^ n)"
      by auto
    have inv: "ys ∈ ?inv"
      using Cons(2) by auto
    note IH = Cons(1)[OF inv]
    define Ys where "Ys = Abs_multiset (count_of ys)"
    (* the multiset of the list with head entry for a is n copies of a added to Ys *)
    have id: "Abs_multiset (count_of ((a, n) # ys)) = (((+) {# a #}) ^^ n) Ys"
      unfolding Ys_def
    proof (rule multiset_eqI, unfold count)
      fix c
      show "count_of ((a, n) # ys) c =
        count (((+) {#a#} ^^ n) (Abs_multiset (count_of ys))) c" (is "?l = ?r")
      proof (cases "c = a")
        case False
        then show ?thesis
          unfolding cs by (induct n) auto
      next
        case True
        then have "?l = n" by (simp add: cs)
        also have "n = ?r" unfolding True
        proof (induct n)
          case 0
          (* keys are distinct, so a has no entry in the tail ys *)
          from Cons(2)[unfolded pair] have "a ∉ fst ` set ys" by auto
          then show ?case by (induct ys) (simp, auto simp: cs)
        next
          case Suc
          then show ?case by simp
        qed
        finally show ?thesis .
      qed
    qed
    show ?case
      unfolding pair
      apply (simp add: IH[symmetric])
      unfolding id Ys_def[symmetric]
      apply (induct n)
      apply (auto simp: fold_mset_fun_left_comm[symmetric])
      done
  qed
qed

end

context
begin

(* Private helper: the one-entry association list [(a, b)]; its distinct-key
   obligation is discharged by auto. *)
private lift_definition single_alist_entry :: "'a ⇒ 'b ⇒ ('a, 'b) alist" is "λa b. [(a, b)]"
  by auto

(* Code equation for image_mset: each stored entry (a, n) contributes the
   singleton Bag mapping f a to n. *)
lemma image_mset_Bag [code]:
  "image_mset f (Bag ms) =
    DAList_Multiset.fold (λa n m. Bag (single_alist_entry (f a) n) + m) {#} ms"
  unfolding image_mset_def
proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, (auto simp: ac_simps)[1])
  fix a n m
  show "Bag (single_alist_entry (f a) n) + m = ((add_mset ∘ f) a ^^ n) m" (is "?l = ?r")
  proof (rule multiset_eqI)
    fix x
    have "count ?r x = (if x = f a then n + count m x else count m x)"
      by (induct n) auto
    also have "… = count ?l x"
      by (simp add: single_alist_entry.rep_eq)
    finally show "count ?l x = count ?r x" ..
  qed
qed

end

― ‹we cannot use ‹λa n. (+) (a * n)› for folding, since ‹(*)› is not defined in ‹comm_monoid_add››
(* Code equation for sum_mset: each entry (a, n) adds a exactly n times. *)
lemma sum_mset_Bag[code]: "sum_mset (Bag ms) = DAList_Multiset.fold (λa n. (((+) a) ^^ n)) 0 ms"
  unfolding sum_mset.eq_fold
  (* instantiate DAList_Multiset_fold; the remaining goals are the comp_fun_commute
     obligation for the folded function and the fn assumption *)
  apply (rule comp_fun_commute.DAList_Multiset_fold)
  apply unfold_locales
  apply (auto simp: ac_simps)
  done

― ‹we cannot use ‹λa n. (*) (a ^ n)› for folding, since ‹(^)› is not defined in ‹comm_monoid_mult››
(* Code equation for prod_mset: each entry (a, n) multiplies by a exactly n times. *)
lemma prod_mset_Bag[code]: "prod_mset (Bag ms) = DAList_Multiset.fold (λa n. (((*) a) ^^ n)) 1 ms"
  unfolding prod_mset.eq_fold
  (* instantiate DAList_Multiset_fold; the remaining goals are the comp_fun_commute
     obligation for the folded function and the fn assumption *)
  apply (rule comp_fun_commute.DAList_Multiset_fold)
  apply unfold_locales
  apply (auto simp: ac_simps)
  done

(* size as a fold_mset that applies Suc once per element occurrence. *)
lemma size_fold: "size A = fold_mset (λ_. Suc) 0 A" (is "_ = fold_mset ?f _ _")
proof -
  (* make the comp_fun_commute facts for λ_. Suc available locally *)
  interpret comp_fun_commute ?f by standard auto
  show ?thesis by (induct A) auto
qed

(* Code equation for size: sum up the stored multiplicities. *)
lemma size_Bag[code]: "size (Bag ms) = DAList_Multiset.fold (λa n. (+) n) 0 ms"
  unfolding size_fold
proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, simp)
  fix a n x
  (* the fn assumption of DAList_Multiset_fold: adding n is Suc iterated n times *)
  show "n + x = (Suc ^^ n) x"
    by (induct n) auto
qed


(* set_mset as a fold_mset that inserts every element into the empty set. *)
lemma set_mset_fold: "set_mset A = fold_mset insert {} A" (is "_ = fold_mset ?f _ _")
proof -
  (* make the comp_fun_commute facts for insert available locally *)
  interpret comp_fun_commute ?f by standard auto
  show ?thesis by (induct A) auto
qed

(* Code equation for set_mset: collect the keys whose stored multiplicity is nonzero. *)
lemma set_mset_Bag[code]:
  "set_mset (Bag ms) = DAList_Multiset.fold (λa n. (if n = 0 then (λm. m) else insert a)) {} ms"
  unfolding set_mset_fold
proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, (auto simp: ac_simps)[1])
  fix a n x
  (* the fn assumption: inserting a n times equals inserting it once, for n > 0 *)
  show "(if n = 0 then λm. m else insert a) x = (insert a ^^ n) x" (is "?l n = ?r n")
  proof (cases n)
    case 0
    then show ?thesis by simp
  next
    case (Suc m)
    then have "?l n = insert a x" by simp
    moreover have "?r n = insert a x" unfolding Suc by (induct m) auto
    ultimately show ?thesis by auto
  qed
qed


instantiation multiset :: (exhaustive) exhaustive
begin

(* Exhaustive quickcheck enumeration for multisets: enumerate ('a, nat) alist
   values and hand each candidate to f as the multiset Bag xs. *)
definition exhaustive_multiset ::
  "('a multiset ⇒ (bool × term list) option) ⇒ natural ⇒ (bool × term list) option"
  where "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (λxs. f (Bag xs)) i"

instance ..

end

end