(* File: approximation_generator.ML *)

(*  Title:      HOL/Decision_Procs/approximation_generator.ML
    Author:     Fabian Immler, TU Muenchen
*)

(* Interface of the "approximation" Quickcheck tester: three configuration
   options, the generator function, and the theory setup that registers it. *)
signature APPROXIMATION_GENERATOR =
sig
  (* integer option, default ~1; a negative value makes the tester take its
     randomness from Random_Engine.run, otherwise a seed is derived from it *)
  val custom_seed: int Config.T
  (* integer option, default 30; passed on as the precision numeral of the
     approx_form term that is handed to Approximation.approximate *)
  val precision: int Config.T
  (* real option, default 0.0; margin added to the left-hand side of the
     floating-point comparisons during the preliminary evaluation *)
  val epsilon: real Config.T
  (* generator for a single goal term; any argument other than a one-element
     list is rejected with an error.  The resulting function takes the
     genuine_only flag and a two-element int list whose second entry is the
     size. *)
  val approximation_generator:
    Proof.context ->
    (term * term list) list ->
    bool -> int list -> (bool * term list) option * Quickcheck.report option
  (* registers the tester under the name "approximation" *)
  val setup: theory -> theory
end;

structure Approximation_Generator : APPROXIMATION_GENERATOR =
struct

(* Configuration options of the tester.  Each option is declared with its
   name as a binding and a constant default value. *)

(* Seed selector, default ~1.  approximation_generator_raw uses
   Random_Engine.run when the value is negative and otherwise starts from the
   seed (custom_seed + 1, 1). *)
val custom_seed =
  Attrib.setup_config_int \<^binding>\<open>quickcheck_approximation_custom_seed\<close> (K ~1)

(* Precision argument of the generated approx_form term, default 30. *)
val precision =
  Attrib.setup_config_int \<^binding>\<open>quickcheck_approximation_precision\<close> (K 30)

(* Margin used by the floating-point pre-check of comparisons, default 0.0. *)
val epsilon =
  Attrib.setup_config_real \<^binding>\<open>quickcheck_approximation_epsilon\<close> (K 0.0)

(* Code-generated random generator, selected by the type constraint: applied
   to a size and a seed it yields a pair of (float, term thunk) and the next
   seed (this is how random_float_list uses it).  The symbols are written in
   Isabelle's ASCII notation so that they do not depend on the file
   encoding. *)
val random_float =
  @{code "random_class.random::_ \<Rightarrow> _ \<Rightarrow> (float \<times> (unit \<Rightarrow> term)) \<times> _"}

(* Read a natural number from a term: first via HOLogic.dest_nat, and if that
   raises TERM, via the value component of HOLogic.dest_number.  When both
   fail, TERM ("nat_of_term", [t]) is raised. *)
fun nat_of_term t =
  let
    fun via_number () =
      let val (_, n) = HOLogic.dest_number t in n end
        handle TERM _ => raise TERM ("nat_of_term", [t])
  in
    HOLogic.dest_nat t handle TERM _ => via_number ()
  end;

(* Value component of HOLogic.dest_number t; a TERM failure is re-raised as
   TERM ("int_of_term", [t]). *)
fun int_of_term t =
  let val (_, i) = HOLogic.dest_number t in i end
    handle TERM _ => raise TERM ("int_of_term", [t]);

(* ML real built by Real.fromManExp from the integer mantissa m and the
   exponent e. *)
fun real_of_man_exp m e =
  let val mantissa = Real.fromInt m
  in Real.fromManExp {exp = e, man = mantissa} end

(* Evaluate a float literal term to an ML real.
   - Float m e: mantissa and exponent are read with int_of_term and combined
     by real_of_man_exp;
   - anything else is read as a plain number via HOLogic.dest_number; if that
     fails, TERM ("mapprox_float", [t]) is raised. *)
fun mapprox_float \<^Const_>\<open>Float for m e\<close> = real_of_man_exp (int_of_term m) (int_of_term e)
  | mapprox_float t = Real.fromInt (snd (HOLogic.dest_number t))
      handle TERM _ => raise TERM ("mapprox_float", [t]);

(* Evaluate a reified arithmetic expression in ML floating point.
   The first argument is the expression term, the second the list of real
   values for its variables: Var n is looked up as the n-th list element.
   A term that matches none of the constructors below raises
   TERM ("mapprox_floatarith", [t]).  Arithmetic exceptions of the underlying
   ML operations (e.g. from floor) are not caught here. *)
(* TODO: define using compiled terms? *)
fun mapprox_floatarith \<^Const_>\<open>Add for a b\<close> xs = mapprox_floatarith a xs + mapprox_floatarith b xs
  | mapprox_floatarith \<^Const_>\<open>Minus for a\<close> xs = ~ (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Mult for a b\<close> xs = mapprox_floatarith a xs * mapprox_floatarith b xs
  | mapprox_floatarith \<^Const_>\<open>Inverse for a\<close> xs = 1.0 / mapprox_floatarith a xs
  | mapprox_floatarith \<^Const_>\<open>Cos for a\<close> xs = Math.cos (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Arctan for a\<close> xs = Math.atan (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Abs for a\<close> xs = abs (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Max for a b\<close> xs =
      Real.max (mapprox_floatarith a xs, mapprox_floatarith b xs)
  | mapprox_floatarith \<^Const_>\<open>Min for a b\<close> xs =
      Real.min (mapprox_floatarith a xs, mapprox_floatarith b xs)
  | mapprox_floatarith \<^Const_>\<open>Pi\<close> _ = Math.pi
  | mapprox_floatarith \<^Const_>\<open>Sqrt for a\<close> xs = Math.sqrt (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Exp for a\<close> xs = Math.exp (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Powr for a b\<close> xs =
      Math.pow (mapprox_floatarith a xs, mapprox_floatarith b xs)
  | mapprox_floatarith \<^Const_>\<open>Ln for a\<close> xs = Math.ln (mapprox_floatarith a xs)
  | mapprox_floatarith \<^Const_>\<open>Power for a n\<close> xs =
      Math.pow (mapprox_floatarith a xs, Real.fromInt (nat_of_term n))
  | mapprox_floatarith \<^Const_>\<open>Floor for a\<close> xs = Real.fromInt (floor (mapprox_floatarith a xs))
  | mapprox_floatarith \<^Const_>\<open>Var for n\<close> xs = nth xs (nat_of_term n)
  | mapprox_floatarith \<^Const_>\<open>Num for m\<close> _ = mapprox_float m
  | mapprox_floatarith t _ = raise TERM ("mapprox_floatarith", [t])

(* Floating-point test that x lies between a and b, where each comparison
   must hold with a margin of eps: a + eps <= x and x + eps <= b.
   All three expressions are evaluated against the variable values xs. *)
fun mapprox_atLeastAtMost eps x a b xs =
  let
    val eval = fn t => mapprox_floatarith t xs
    val mid = eval x
  in
    eval a + eps <= mid andalso mid + eps <= eval b
  end

(* Evaluate a reified formula in ML floating point, with variable values xs.
   - Bound x a b f: true when x is not within [a, b] (checked with margin
     eps), otherwise the value of f;
   - Assign x a f: true when x and a evaluate to different reals, otherwise
     the value of f;
   - Less / LessEqual: the left-hand side plus eps is compared with the
     right-hand side;
   - AtLeastAtMost, Conj, Disj: as their names say, via
     mapprox_atLeastAtMost, andalso, orelse.
   Any other term raises TERM ("mapprox_form", [t]). *)
fun mapprox_form eps \<^Const_>\<open>Bound for x a b f\<close> xs =
    (not (mapprox_atLeastAtMost eps x a b xs)) orelse mapprox_form eps f xs
| mapprox_form eps \<^Const_>\<open>Assign for x a f\<close> xs =
    (Real.!= (mapprox_floatarith x xs, mapprox_floatarith a xs)) orelse mapprox_form eps f xs
| mapprox_form eps \<^Const_>\<open>Less for a b\<close> xs = mapprox_floatarith a xs + eps < mapprox_floatarith b xs
| mapprox_form eps \<^Const_>\<open>LessEqual for a b\<close> xs = mapprox_floatarith a xs + eps <= mapprox_floatarith b xs
| mapprox_form eps \<^Const_>\<open>AtLeastAtMost for x a b\<close> xs = mapprox_atLeastAtMost eps x a b xs
| mapprox_form eps \<^Const_>\<open>Conj for f g\<close> xs = mapprox_form eps f xs andalso mapprox_form eps g xs
| mapprox_form eps \<^Const_>\<open>Disj for f g\<close> xs = mapprox_form eps f xs orelse mapprox_form eps g xs
| mapprox_form _ t _ = raise TERM ("mapprox_form", [t])

(* Split an application of interpret_form into its two arguments (b, xs);
   any other term raises TERM ("dest_interpret_form", [t]). *)
fun dest_interpret_form \<^Const_>\<open>interpret_form for b xs\<close> = (b, xs)
  | dest_interpret_form t = raise TERM ("dest_interpret_form", [t])

(* Draw one random_float result per element of xs, threading the seed through
   the draws.  Each new result is consed onto the front of the accumulated
   list; the elements of xs themselves are ignored.  Returns the list together
   with the final seed. *)
fun random_float_list size xs seed =
  let
    fun draw _ (acc, s) =
      let val (r, s') = random_float size s
      in (r :: acc, s') end
  in
    fold draw xs ([], seed)
  end

(* Convert a code-generated Float value to an ML real: both components are
   mapped through integer_of_int and combined by real_of_man_exp. *)
fun real_of_Float (@{code Float} (mantissa, exponent)) =
  let
    val to_integer = @{code integer_of_int}
  in
    real_of_man_exp (to_integer mantissa) (to_integer exponent)
  end

(* Numeral term of type int for a code-generated int value. *)
fun term_of_int i =
  i |> @{code integer_of_int} |> HOLogic.mk_number @{typ int}

(* Term "Float m e" for a code-generated Float value, with both components
   rendered by term_of_int. *)
fun term_of_Float (@{code Float} (mantissa, exponent)) =
  let
    val float_const = @{term Float}
  in
    float_const $ term_of_int mantissa $ term_of_int exponent
  end

(* Recognize the constant True; every other term yields false. *)
fun is_True \<^Const_>\<open>True\<close> = true
  | is_True _ = false

(* Rewrite rules, proved once by auto, that approx_random applies (through
   rewrite_with) to the "real_of_float (Float m e)" terms it reports, so that
   they are shown as numerals, powers of two, products and quotients.  The
   last five rules evaluate int_of_integer on literals. *)
val postproc_form_eqs =
  @{lemma
    "real_of_float (Float 0 a) = 0"
    "real_of_float (Float (numeral m) 0) = numeral m"
    "real_of_float (Float 1 0) = 1"
    "real_of_float (Float (- 1) 0) = - 1"
    "real_of_float (Float 1 (numeral e)) = 2 ^ numeral e"
    "real_of_float (Float 1 (- numeral e)) = 1 / 2 ^ numeral e"
    "real_of_float (Float a 1) = a * 2"
    "real_of_float (Float a (-1)) = a / 2"
    "real_of_float (Float (- a) b) = - real_of_float (Float a b)"
    "real_of_float (Float (numeral m) (numeral e)) = numeral m * 2 ^ (numeral e)"
    "real_of_float (Float (numeral m) (- numeral e)) = numeral m / 2 ^ (numeral e)"
    "- (c * d::real) = -c * d"
    "- (c / d::real) = -c / d"
    "- (0::real) = 0"
    "int_of_integer (numeral k) = numeral k"
    "int_of_integer (- numeral k) = - numeral k"
    "int_of_integer 0 = 0"
    "int_of_integer 1 = 1"
    "int_of_integer (- 1) = - 1"
    by auto
  }

(* Conversion that rewrites with exactly the given theorems on top of
   HOL_basic_ss. *)
fun rewrite_with ctxt thms =
  let
    val simp_ctxt = put_simpset HOL_basic_ss ctxt addsimps thms
  in
    Simplifier.rewrite simp_ctxt
  end
(* Apply a conversion to the term r (certified in ctxt) and return the
   right-hand side of the resulting equation. *)
fun conv_term ctxt conv r =
  let
    val eq_thm = conv (Thm.cterm_of ctxt r)
    val (_, rhs) = Logic.dest_equals (Thm.prop_of eq_thm)
  in
    rhs
  end

(* One random test of the reified formula e.

   Arguments: proof context, precision prec, margin eps, the free variables
   (name, type) of the original goal, the formula e, the list xs of variable
   terms of e, the genuine_only flag, the size for the random generator and
   the seed.

   A random float is drawn for every element of xs.  The formula is first
   evaluated in ML floating point (mapprox_form); if that yields false or
   raises Overflow, Domain, Div or Unordered, the result is NONE.  Otherwise
   the assignment is checked with Approximation.approximate on an approx_form
   term built from the drawn floats:
     - result True:  SOME (true, ts')
     - otherwise:    SOME (false, ts'), or NONE when genuine_only is set.
   ts' lists, in the order of frees, the assigned value of each free variable
   as a real_of_float term simplified with postproc_form_eqs; a free variable
   that does not occur in xs gets Term.dummy as its float argument.

   The second component of the result is the seed left after drawing. *)
fun approx_random ctxt prec eps frees e xs genuine_only size seed =
  let
    val (rs, seed') = random_float_list size xs seed
    (* approx_form prec e [Some (interval_of t), ...] [] *)
    fun mk_approx_form e ts =
      \<^Const>\<open>approx_form\<close> $
        HOLogic.mk_number \<^Type>\<open>nat\<close> prec $
        e $
        (HOLogic.mk_list \<^typ>\<open>float interval option\<close>
          (map
            (fn t =>
              \<^Const>\<open>Some \<^typ>\<open>float interval\<close>\<close> $
                \<^Const>\<open>interval_of \<^Type>\<open>float\<close> for t\<close>)
            ts)) $
        \<^Const>\<open>Nil \<^Type>\<open>nat\<close>\<close>
  in
    (if
      mapprox_form eps e (map (real_of_Float o fst) rs)
      handle
        General.Overflow => false
      | General.Domain => false
      | General.Div => false
      | IEEEReal.Unordered => false
    then
      let
        val ts = map (fn x => term_of_Float (fst x)) rs
        val ts' = map
          (AList.lookup op = (map dest_Free xs ~~ ts)
            #> the_default Term.dummy
            #> curry op $ \<^Const>\<open>real_of_float\<close>
            #> conv_term ctxt (rewrite_with ctxt postproc_form_eqs))
          frees
      in
        if Approximation.approximate ctxt (mk_approx_form e ts) |> is_True
        then SOME (true, ts')
        else (if genuine_only then NONE else SOME (false, ts'))
      end
    else NONE, seed')
  end

(* Rewrite rules, proved once by auto, that reify_goal applies before
   reification: interval membership and equality are expressed through
   less-or-equal and conjunction, and implication and bi-implication through
   negation, disjunction and conjunction.  The symbols are written in
   Isabelle's ASCII notation so that they do not depend on the file
   encoding. *)
val preproc_form_eqs =
  @{lemma
    "(a::real) \<in> {b .. c} \<longleftrightarrow> b \<le> a \<and> a \<le> c"
    "a = b \<longleftrightarrow> a \<le> b \<and> b \<le> a"
    "(p \<longrightarrow> q) \<longleftrightarrow> \<not>p \<or> q"
    "(p \<longleftrightarrow> q) \<longleftrightarrow> (p \<longrightarrow> q) \<and> (q \<longrightarrow> p)"
    by auto}

(* Negate the goal t, normalize it with preproc_form_eqs, reify it with
   Approximation.reify_form, and return the reified formula together with the
   ML list of terms taken from the second argument of interpret_form. *)
fun reify_goal ctxt t =
  let
    val negated = HOLogic.mk_not t
    val normalized = conv_term ctxt (rewrite_with ctxt preproc_form_eqs) negated
    val (form, env) = dest_interpret_form (Approximation.reify_form ctxt normalized)
  in
    (form, HOLogic.dest_list env)
  end

(* Build the tester for the goal term t.

   Reads Quickcheck.iterations, precision, epsilon and custom_seed from ctxt
   and reifies the negated goal once.  The returned function takes
   genuine_only and a size and performs up to "iterations" random tests with
   approx_random, stopping at the first test that yields a result.  The
   report component of the result is always NONE.

   Randomness: with a negative custom_seed every test is run through
   Random_Engine.run and the explicit seed is passed along unchanged;
   otherwise the seed starts at (custom_seed + 1, 1) and each test continues
   with the seed returned by the previous one. *)
fun approximation_generator_raw ctxt t =
  let
    val iterations = Config.get ctxt Quickcheck.iterations
    val prec = Config.get ctxt precision
    val eps = Config.get ctxt epsilon
    val cs = Config.get ctxt custom_seed
    val seed = (Code_Numeral.natural_of_integer (cs + 1), Code_Numeral.natural_of_integer 1)
    val run = if cs < 0
      then (fn f => fn seed => (Random_Engine.run f, seed))
      else (fn f => fn seed => f seed)
    val frees = Term.add_frees t []
    val (e, xs) = reify_goal ctxt t
    fun single_tester b s =
      approx_random ctxt prec eps frees e xs b s |> run
    (* j counts the remaining attempts; any non-positive count ends the
       search, so a negative iterations setting cannot loop forever *)
    fun iterate genuine_only size j seed =
      if j <= 0 then NONE
      else
        (case single_tester genuine_only size seed of
          (NONE, seed') => iterate genuine_only size (j - 1) seed'
        | (SOME q, _) => SOME q)
  in
    fn genuine_only => fn size => (iterate genuine_only size iterations seed, NONE)
  end

(* Generator entry point.  Exactly one goal term is supported; the second
   component of the pair is ignored.  The returned function expects a
   two-element int list and uses its second element as the size.  Any other
   goal list is rejected with an error. *)
fun approximation_generator ctxt goals =
  (case goals of
    [(t, _)] =>
      (fn genuine_only =>
        fn [_, size] =>
          approximation_generator_raw ctxt t genuine_only
            (Code_Numeral.natural_of_integer size))
  | _ =>
      error "Quickcheck-approximation does not support type variables (or finite instantiations)")

(* Goal tester obtained from Quickcheck_Common.generator_test_goal_terms for
   the generator named "approximation"; the first component of the pair is a
   two-argument function that is constantly false. *)
val test_goals =
  let
    fun always_false _ _ = false
    val generator = ("approximation", (always_false, approximation_generator))
  in
    Quickcheck_Common.generator_test_goal_terms generator
  end

(* Boolean option that is registered together with the tester in setup;
   default false. *)
val active =
  Attrib.setup_config_bool \<^binding>\<open>quickcheck_approximation_active\<close> (K false)

(* Theory setup: add the tester "approximation", consisting of the active
   option and test_goals, to Quickcheck. *)
val setup =
  let
    val tester = ("approximation", (active, test_goals))
  in
    Context.theory_map (Quickcheck.add_tester tester)
  end

end