(* Theory Mod_Exp *)
section ‹Fast modular exponentiation›
theory Mod_Exp
imports Cong "HOL-Library.Power_By_Squaring"
begin
context euclidean_semiring_cancel
begin
(* Auxiliary function for modular exponentiation.
   The first argument m is the modulus; the function is efficient_funpow
   instantiated with the binary operation "multiply, then reduce modulo m".
   By the code equation mod_exp_aux_code, the remaining arguments act as an
   accumulator y, a base x and an exponent n; by mod_exp_aux_correct,
   mod_exp_aux m y x n is congruent to x ^ n * y modulo m. *)
definition mod_exp_aux :: "'a ⇒ 'a ⇒ 'a ⇒ nat ⇒ 'a"
where "mod_exp_aux m = efficient_funpow (λx y. x * y mod m)"
(* Code equation for mod_exp_aux: a recursion that halves the exponent n.
     n = 0   : return the accumulator y unchanged
     n = 1   : multiply x into y once and reduce modulo m
     n even  : square x (mod m), keep y, continue with n div 2
     n odd   : multiply x into y (mod m), square x (mod m), continue with n div 2
   The proof unfolds the definition and applies efficient_funpow_code. *)
lemma mod_exp_aux_code [code]:
"mod_exp_aux m y x n =
(if n = 0 then y
else if n = 1 then (x * y) mod m
else if even n then mod_exp_aux m y ((x * x) mod m) (n div 2)
else mod_exp_aux m ((x * y) mod m) ((x * x) mod m) (n div 2))"
unfolding mod_exp_aux_def by (rule efficient_funpow_code)
(* Correctness of mod_exp_aux: modulo m, its result equals x ^ n * y.
   Note that the statement reduces the left-hand side modulo m as well, so it
   is a congruence statement rather than a plain equality. *)
lemma mod_exp_aux_correct:
"mod_exp_aux m y x n mod m = (x ^ n * y) mod m"
proof -
(* Step 1: expose the underlying efficient_funpow call. *)
have "mod_exp_aux m y x n = efficient_funpow (λx y. x * y mod m) y x n"
by (simp add: mod_exp_aux_def)
(* Step 2: efficient_funpow agrees with the naive n-fold iteration of
   "multiply by x, reduce modulo m"; the side condition of
   efficient_funpow_correct is discharged by simp using the mod/mult lemmas. *)
also have "… = ((λy. x * y mod m) ^^ n) y"
by (rule efficient_funpow_correct) (simp add: mod_mult_left_eq mod_mult_right_eq mult_ac)
(* Step 3: the naive iteration is congruent to x ^ n * y modulo m,
   shown by induction on n. *)
also have "((λy. x * y mod m) ^^ n) y mod m = (x ^ n * y) mod m"
proof (induction n)
case (Suc n)
(* Multiply the induction hypothesis by x on both sides, under mod m. *)
hence "x * ((λy. x * y mod m) ^^ n) y mod m = x * x ^ n * y mod m"
by (metis mod_mult_right_eq mult.assoc)
thus ?case by auto
qed auto
finally show ?thesis .
qed
(* Modular exponentiation, specified directly as b ^ e mod m
   (base b, exponent e, modulus m). This is the specification only; the
   code equation mod_exp_code evaluates it through mod_exp_aux instead. *)
definition mod_exp :: "'a ⇒ nat ⇒ 'a ⇒ 'a"
where "mod_exp b e m = (b ^ e) mod m"
(* Code equation for mod_exp: run mod_exp_aux with accumulator 1 and reduce
   the result modulo m once more. Follows from mod_exp_aux_correct
   instantiated with y = 1. *)
lemma mod_exp_code [code]: "mod_exp b e m = mod_exp_aux m 1 b e mod m"
proof -
  have "mod_exp_aux m 1 b e mod m = (b ^ e * 1) mod m"
    by (rule mod_exp_aux_correct)
  thus ?thesis
    by (simp add: mod_exp_def)
qed
end
(* Register mod_exp_def, instantiated at types nat and int only, as code
   abbreviations. NOTE(review): per the code generator documentation,
   code_abbrev equations are applied symmetrically by the pre- and
   postprocessor, so this should let terms of the form b ^ e mod m be
   evaluated through mod_exp; confirm against the codegen manual. *)
lemmas [code_abbrev] = mod_exp_def[where ?'a = nat] mod_exp_def[where ?'a = int]
(* Code unfolding rule (type nat): a congruence whose left-hand side is a
   power b ^ e is rewritten into a comparison that uses mod_exp.
   After unfolding cong and mod_exp both sides of the equivalence coincide. *)
lemma cong_power_nat_code [code_unfold]:
  "[b ^ e = (x ::nat)] (mod m) ⟷ mod_exp b e m = x mod m"
  unfolding cong_def mod_exp_def by simp
(* Code unfolding rule (type int): a congruence whose left-hand side is a
   power b ^ e is rewritten into a comparison that uses mod_exp.
   After unfolding cong and mod_exp both sides of the equivalence coincide. *)
lemma cong_power_int_code [code_unfold]:
  "[b ^ e = (x ::int)] (mod m) ⟷ mod_exp b e m = x mod m"
  unfolding cong_def mod_exp_def by simp
text ‹
The following rules allow the simplifier to evaluate @{const mod_exp} efficiently.
›
(* Simp rules that evaluate mod_exp_aux on exponents given as 0, Suc 0, or a
   binary numeral. The two numeral rules strip one bit per rewrite step:
     Bit0 n (even): square x modulo m, keep y, continue with numeral n
     Bit1 n (odd) : multiply x into y modulo m, square x modulo m,
                    continue with numeral n *)
lemma eval_mod_exp_aux [simp]:
"mod_exp_aux m y x 0 = y"
"mod_exp_aux m y x (Suc 0) = (x * y) mod m"
"mod_exp_aux m y x (numeral (num.Bit0 n)) =
mod_exp_aux m y (x⇧2 mod m) (numeral n)"
"mod_exp_aux m y x (numeral (num.Bit1 n)) =
mod_exp_aux m ((x * y) mod m) (x⇧2 mod m) (numeral n)"
proof -
(* n' abbreviates numeral n as a natural number. *)
define n' where "n' = (numeral n :: nat)"
(* n' is nonzero; presumably needed so that simp can rule out the n = 0 and
   n = 1 branches of mod_exp_aux_code in the Bit1 case below. *)
have [simp]: "n' ≠ 0" by (auto simp: n'_def)
(* Base cases: exponent 0 and exponent Suc 0, directly from the definition. *)
show "mod_exp_aux m y x 0 = y" and "mod_exp_aux m y x (Suc 0) = (x * y) mod m"
by (simp_all add: mod_exp_aux_def)
(* Even case: rewrite numeral (Bit0 n) as 2 * n'. arith_simps are removed
   from the simpset here, presumably so that the numeral arithmetic is not
   folded back into a numeral. *)
have "numeral (num.Bit0 n) = (2 * n')"
by (subst numeral.numeral_Bit0) (simp del: arith_simps add: n'_def)
(* One step of the code equation on the even exponent 2 * n'. *)
also have "mod_exp_aux m y x … = mod_exp_aux m y (x^2 mod m) n'"
by (subst mod_exp_aux_code) (simp_all add: power2_eq_square)
finally show "mod_exp_aux m y x (numeral (num.Bit0 n)) =
mod_exp_aux m y (x⇧2 mod m) (numeral n)"
by (simp add: n'_def)
(* Odd case: rewrite numeral (Bit1 n) as Suc (2 * n'). *)
have "numeral (num.Bit1 n) = Suc (2 * n')"
by (subst numeral.numeral_Bit1) (simp del: arith_simps add: n'_def)
(* One step of the code equation on the odd exponent Suc (2 * n'). *)
also have "mod_exp_aux m y x … = mod_exp_aux m ((x * y) mod m) (x^2 mod m) n'"
by (subst mod_exp_aux_code) (simp_all add: power2_eq_square)
finally show "mod_exp_aux m y x (numeral (num.Bit1 n)) =
mod_exp_aux m ((x * y) mod m) (x⇧2 mod m) (numeral n)"
by (simp add: n'_def)
qed
(* Simp rules that evaluate mod_exp.
   The first eleven statements cover special cases: exponent 0 or 1,
   modulus 0 or 1, and base 0 or 1 (each of 1 also in the form Suc 0).
   The last statement handles the case where base, exponent and modulus are
   all numerals by delegating to mod_exp_aux with accumulator 1; it has the
   same shape as the code equation mod_exp_code.
   All statements follow from mod_exp_def and mod_exp_aux_correct by simp. *)
lemma eval_mod_exp [simp]:
"mod_exp b' 0 m' = 1 mod m'"
"mod_exp b' 1 m' = b' mod m'"
"mod_exp b' (Suc 0) m' = b' mod m'"
"mod_exp b' e' 0 = b' ^ e'"
"mod_exp b' e' 1 = 0"
"mod_exp b' e' (Suc 0) = 0"
"mod_exp 0 1 m' = 0"
"mod_exp 0 (Suc 0) m' = 0"
"mod_exp 0 (numeral e) m' = 0"
"mod_exp 1 e' m' = 1 mod m'"
"mod_exp (Suc 0) e' m' = 1 mod m'"
"mod_exp (numeral b) (numeral e) (numeral m) =
mod_exp_aux (numeral m) 1 (numeral b) (numeral e) mod numeral m"
by (simp_all add: mod_exp_def mod_exp_aux_correct)
end