(* =========================================================== *)
(* Main formal verification functions                          *)
(* Author: Alexey Solovyev                                     *)
(* Date: 2012-10-27                                            *)
(* =========================================================== *)

needs "verifier/interval_m/verifier.ml";;
needs "verifier/m_verifier.hl";;
needs "verifier/m_verifier_build.hl";;
needs "taylor/m_taylor_arith2.hl";;
needs "misc/vars.hl";;

#load "unix.cma";;

module M_verifier_main = struct

open Arith_misc;;
open Interval_arith;;
open Eval_interval;;
open More_float;;
open M_verifier;;
open M_verifier_build;;
open M_taylor;;
open M_taylor_arith2;;
open Taylor;;
open Misc_vars;;
open Verifier_options;;



(* Parameters of the verification procedure. *)
(* verify_ineq0 passes every field except adaptive_precision *)
(* to Verifier.run_test; adaptive_precision is used by verify_ineq0 itself *)
type verification_parameters =
{
  (* If true, then monotonicity properties can be used *)
  (* to reduce the dimension of a problem *)
  allow_derivatives : bool;
  (* If true, then convexity can be used *)
  (* to reduce the dimension of a problem *)
  convex_flag : bool;
  (* If true, then verification on internal subdomains can be skipped *)
  (* for a monotone function *)
  mono_pass_flag : bool;
  (* If true, then raw interval arithmetic can be used *)
  (* (without Taylor approximations) *)
  raw_intervals_flag : bool;
  (* If true, then an informal procedure is used to determine *)
  (* the optimal precision for the formal verification; *)
  (* if false, then the formal verification is applied to the certificate directly *)
  adaptive_precision : bool;
  (* A value passed to the certificate search procedure. *)
  (* This parameter might be used in cases when the certificate search *)
  (* procedure returns a wrong result due to rounding errors *)
  (* (this parameter will be eliminated when the search procedure is corrected) *)
  eps : float;
};;

(* Default parameters: all optional techniques are enabled and eps is zero. *)
(* Note that mk_verification_functions switches off raw_intervals_flag and *)
(* convex_flag for non-polynomial functions regardless of these values *)
let default_params =
{
  allow_derivatives = true;
  convex_flag = true;
  mono_pass_flag = true;
  raw_intervals_flag = true;
  adaptive_precision = true;
  eps = 0.0;
};;


(* Statistics returned by verify_ineq0 together with the final theorem. *)
(* Both times are differences of Unix.gettimeofday values *)
type verification_stats =
{
  (* Time elapsed from the start of verify_ineq0 *)
  (* until the formal verification step is finished *)
  total_time : float;
  (* Time spent in the formal verification step *)
  formal_verification_time : float;
  (* Statistics of the solution certificate (Verifier.result_stats), *)
  (* or Verifier.dummy_stats for an inequality without free variables *)
  certificate : Verifier.certificate_stats;
};;


(********************************)

(* Registers the given constant approximation in both tables of known *)
(* constants: first in the table of Eval_interval and then in the table *)
(* of Informal_eval_interval *)
let add_constant_interval int_th =
  begin
    Eval_interval.add_constant_interval int_th;
    Informal_eval_interval.add_constant_interval int_th
  end;;


(* Tests if an expression has only given binary and unary operations. *)
(* bin_ops and unary_ops are lists of names of allowed constants. *)
(* The leaves of an accepted expression are closed terms, variables, *)
(* and vector components of the form `a$i` *)
let test_expression bin_ops unary_ops =
  (* Succeeds only for terms of the form `a$i` where a is a variable *)
  (* and i is a small numeral *)
  let dest_component tm =
    let var, index = dest_binary "$" tm in
      dest_var var, dest_small_numeral index in
  let rec test tm =
    frees tm = [] or
      can dest_var tm or
      can dest_component tm or
      test_unary tm or
      test_binary tm
  (* A valid unary operation: an allowed constant applied to a valid argument *)
  and test_unary tm =
    try
      let op, arg = dest_comb tm in
      let c, _ = dest_const op in
        mem c unary_ops && test arg
    with Failure _ -> false
  (* A valid binary operation: an allowed constant applied to two valid arguments *)
  and test_binary tm =
    try
      let op_lhs, rhs = dest_comb tm in
      let op, lhs = dest_comb op_lhs in
      let c, _ = dest_const op in
        mem c bin_ops && (test lhs && test rhs)
    with Failure _ -> false in
    test;;


(* Tests if the given expression is a polynomial expression, i.e., *)
(* it is built with addition, multiplication, subtraction, power, *)
(* and negation only *)
let is_poly =
  test_expression
    ["real_add"; "real_mul"; "real_sub"; "real_pow"]
    ["real_neg"];;


(**********************************)

(* Creates basic verification functions *)
(* Arguments: n (the dimension passed to the formal Taylor operations), *)
(* pp (passed to mk_verification_functions_poly), *)
(* fun_tm (an abstraction; its body is analyzed recursively). *)
(* Returns a triple (f, tf, ti) where *)
(*   f  is a function of p1 p2 x built from the eval_m_taylor_*2 operations, *)
(*   tf is a description of the function built from the constructors *)
(*      Plus, Scale, Product, Uni_compose *)
(*      (verify_ineq0 passes it to Verifier.run_test), *)
(*   ti is a function of p1 p2 x built from the Informal_taylor operations. *)
(* NOTE(review): p1 and p2 look like precision parameters -- confirm *)
(* Fails if the body is neither a polynomial nor an application *)
(* of a supported operation *)
let rec mk_funs =
  (* add *)
  (* Each mk_* function combines triples of its arguments into a triple *)
  (* for the result of the corresponding operation *)
  let mk_add n (f1, tf1, ti1) (f2, tf2, ti2) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x and
	   b = f2 p1 p2 x in
	 eval_m_taylor_add2 n p1 p2 a b),
    Plus (tf1, tf2),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x and
	   b = ti2 p1 p2 x in
	 Informal_taylor.eval_m_taylor_add p1 p2 a b) in
    (* sub *)
    (* The description of a difference is a sum with the second term scaled by -1 *)
  let mk_sub n (f1, tf1, ti1) (f2, tf2, ti2) =
    let neg_one = Interval.mk_interval(-1.0, -1.0) in
      (fun p1 p2 x -> 
	 let a = f1 p1 p2 x and
	     b = f2 p1 p2 x in
	   eval_m_taylor_sub2 n p1 p2 a b),
    Plus (tf1, Scale(tf2, neg_one)),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x and
	   b = ti2 p1 p2 x in
	 Informal_taylor.eval_m_taylor_sub p1 p2 a b) in
    (* mul *)
  let mk_mul n (f1, tf1, ti1) (f2, tf2, ti2) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x and
	   b = f2 p1 p2 x in
	 eval_m_taylor_mul2 n p1 p2 a b),
    Product (tf1, tf2),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x and
	   b = ti2 p1 p2 x in
	 Informal_taylor.eval_m_taylor_mul p1 p2 a b) in
    (* neg *)
    (* The negation operations do not take p1 and p2 *)
  let mk_neg n (f1, tf1, ti1) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x in
	 eval_m_taylor_neg2 n a),
    Scale (tf1, Interval.mk_interval (-1.0, -1.0)),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x in
	 Informal_taylor.eval_m_taylor_neg a) in
    (* sqrt *)
  let mk_sqrt n (f1, tf1, ti1) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x in
	 eval_m_taylor_sqrt2 n p1 p2 a),
    Uni_compose (Univariate.usqrt, tf1),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x in
	 Informal_taylor.eval_m_taylor_sqrt p1 p2 a) in
    (* inv *)
  let mk_inv n (f1, tf1, ti1) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x in
	 eval_m_taylor_inv2 n p1 p2 a),
    Uni_compose (Univariate.uinv, tf1),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x in
	 Informal_taylor.eval_m_taylor_inv p1 p2 a) in
    (* atn *)
  let mk_atn n (f1, tf1, ti1) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x in
	 eval_m_taylor_atn2 n p1 p2 a),
    Uni_compose (Univariate.uatan, tf1),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x in
	 Informal_taylor.eval_m_taylor_atn p1 p2 a) in
    (* acs *)
  let mk_acs n (f1, tf1, ti1) =
    (fun p1 p2 x -> 
       let a = f1 p1 p2 x in
	 eval_m_taylor_acs2 n p1 p2 a),
    Uni_compose (Univariate.uacos, tf1),
    (fun p1 p2 x ->
       let a = ti1 p1 p2 x in
	 Informal_taylor.eval_m_taylor_acs p1 p2 a) in
    (* binary operations *)	   
    (* (names of constants paired with the corresponding constructors) *)
  let bin_ops = 
    ["real_add", mk_add;
     "real_sub", mk_sub;
     "real_mul", mk_mul] in
    (* unary operations *)
  let unary_ops =
    ["real_neg", mk_neg;
     "sqrt", mk_sqrt;
     "atn", mk_atn;
     "acs", mk_acs;
     "real_inv", mk_inv] in
    (* makes a binary operation *)
    (* (both arguments are abstracted over x_var and processed recursively; *)
    (* fails if tm is not an application of a constant from bin_ops) *)
  let mk_bin n pp x_var tm =
    let lhs, rhs = dest_comb tm in
    let op, lhs = dest_comb lhs in
    let mk_f = assoc ((fst o dest_const) op) bin_ops in
    let l_funs = mk_funs n pp (mk_abs(x_var, lhs)) and
	r_funs = mk_funs n pp (mk_abs(x_var, rhs)) in
      mk_f n l_funs r_funs in
    (* makes an unary operation *)
    (* (fails if tm is not an application of a constant from unary_ops) *)
  let mk_unary n pp x_var tm =
    let op, rhs = dest_comb tm in
    let mk_f = assoc ((fst o dest_const) op) unary_ops in
    let funs = mk_funs n pp (mk_abs(x_var, rhs)) in
      mk_f n funs in
    (* the main function *)
    fun n pp fun_tm ->
      let x_var, body_tm = dest_abs fun_tm in
	(* Polynomial (sub)expressions are delegated to the polynomial procedure *)
	if is_poly body_tm then
	  let eval_fs, tf, ti = mk_verification_functions_poly pp fun_tm in
	    eval_fs.taylor, tf, ti.Informal_verifier.taylor
	else
	  (* Any Failure raised while processing the body as a binary operation *)
	  (* (including failures in subexpressions) leads to the unary attempt *)
	  try mk_bin n pp x_var body_tm with Failure _ ->
	    mk_unary n pp x_var body_tm;;

	   
(* Prepares verification functions *)
(* fun_tm must be in the form `\x. f x` *)
(* params is a reference to verification parameters. *)
(* For a polynomial body, the result of mk_verification_functions_poly is *)
(* returned and params is not changed. Otherwise, only the taylor fields are *)
(* defined (with mk_funs), all other fields are dummy functions which fail, *)
(* and raw_intervals_flag and convex_flag are switched off in params *)
let mk_verification_functions =
  (* Placeholders for functions which are not available *)
  (* for non-polynomial expressions *)
  let dummy_f pp lo hi = failwith "dummy f" in
  let dummy_df i pp lo hi = failwith "dummy df" in
  let dummy_ddf i j pp lo hi = failwith "dummy ddf" in
  let dummy_diff2 lo hi = failwith "dummy diff2" in
    fun params pp fun_tm ->
      let x_var, body_tm = dest_abs fun_tm in
        if is_poly body_tm then
          mk_verification_functions_poly pp fun_tm
        else
          let dim = get_dim x_var in
          let formal_taylor, search_fun, informal_taylor = mk_funs dim pp fun_tm in
          (* The dummy functions must never be called by the search procedure *)
          let () = params := {!params with raw_intervals_flag = false; convex_flag = false} in
            {taylor = formal_taylor;
             f = dummy_f;
             df = dummy_df;
             ddf = dummy_ddf;
             diff2_f = dummy_diff2},
            search_fun,
            {Informal_verifier.taylor = informal_taylor;
             Informal_verifier.f = dummy_f;
             Informal_verifier.df = dummy_df;
             Informal_verifier.ddf = dummy_ddf};;

   
(********************************)

(* Evaluates all elements of the list term list_tm with eval_interval_fun *)
(* and returns the list term which consists of the first (if lo_flag is true) *)
(* or of the second (if lo_flag is false) components of the pairs in *)
(* the right hand sides of the resulting interval theorems *)
let convert_to_float_list pp lo_flag list_tm =
  let eval f = eval_interval_fun pp f [] [] in
  let int_ths = map eval (map build_interval_fun (dest_list list_tm)) in
  let select = if lo_flag then fst else snd in
  let bound_of int_th = select (dest_pair (rand (concl int_th))) in
    mk_list (map bound_of int_ths, real_ty);;


(* Creates a theorem |- interval[xx_tm, zz_tm] SUBSET interval[float(xx_tm), float(zz_tm)]
   and two lists: float(xx_tm) and float(zz_tm).
   The theorem is an instance of subset_interval_thms_array.(n), where n is
   the length of xx_tm, with its hypotheses proved from the interval
   approximations of the elements of xx_tm and zz_tm *)
let mk_float_domain pp (xx_tm, zz_tm) =
  let xx_list = dest_list xx_tm in
  let zz_list = dest_list zz_tm in
  let n = length xx_list in
  (* Interval approximations of all terms in a list *)
  let eval_all tms =
    let i_funs = map build_interval_fun tms in
      map (fun f -> eval_interval_fun pp f [] []) i_funs in
  let xx_ints = eval_all xx_list in
  let zz_ints = eval_all zz_list in
  let unfold = ONCE_REWRITE_RULE[interval_arith] in
  (* The first conjunct is taken for xx and the second conjunct is taken for zz *)
  let xx_ineqs = map (fun th -> CONJUNCT1 (unfold th)) xx_ints in
  let zz_ineqs = map (fun th -> CONJUNCT2 (unfold th)) zz_ints in
  let a_vals = map (fun th -> lhand (concl th)) xx_ineqs in
  let b_vals = map (fun th -> rand (concl th)) zz_ineqs in
  let a_vars = mk_real_vars n "a" in
  let b_vars = mk_real_vars n "b" in
  let c_vars = mk_real_vars n "c" in
  let d_vars = mk_real_vars n "d" in
  let template_th = subset_interval_thms_array.(n) in
  let th0 = (INST (zip xx_list c_vars) o INST (zip zz_list d_vars) o
               INST (zip a_vals a_vars) o INST (zip b_vals b_vars)) template_th in
  let subset_th = itlist MY_PROVE_HYP (xx_ineqs @ zz_ineqs) th0 in
    subset_th, (mk_list (a_vals, real_ty), mk_list (b_vals, real_ty));;



(* Given a term `a < b`, returns the theorem |- a < b <=> a - b < &0 *)
(* (the original term is on the left hand side of the result). *)
(* Also, deals with > and / : real_gt and real_div are rewritten first *)
(* and let-expressions are reduced. *)
(* A user can provide additional rewrite theorems (thms) *)
let mk_standard_ineq =
  let lt_lemma = REAL_ARITH `a < b <=> a - b < &0` in
    fun thms tm ->
      let unfold_conv =
        REWRITE_CONV ([real_gt; real_div] @ thms) THENC DEPTH_CONV let_CONV in
      let unfold_th = unfold_conv tm in
      (* Moves everything to the left hand side and removes trivial zeros *)
      let sub_conv =
        ONCE_REWRITE_CONV [lt_lemma] THENC
          PURE_REWRITE_CONV [REAL_NEG_0; REAL_SUB_RZERO; REAL_SUB_LZERO] in
      let sub_th = sub_conv (rand (concl unfold_th)) in
        TRANS unfold_th sub_th;;


(* Converts a term in the form `x + y` into the term `\x:real^2. x$1 + x$2` *)
(* Free variables are sorted and the i-th variable is replaced by `x$i`. *)
(* Returns the abstraction and the vector of the sorted free variables. *)
(* A one-dimensional vector variable is used for a term without free variables *)
let expr_to_vector_fun =
  let comp_op = `$` in
    fun expr_tm ->
      let vars = List.sort Pervasives.compare (frees expr_tm) in
      let n = length vars in
      let dim = if n = 0 then 1 else n in
      let x_var = mk_var ("x", n_vector_type_array.(dim)) in
      let component = mk_icomb (comp_op, x_var) in
      let components = map (fun i -> mk_comb (component, mk_small_numeral i)) (1--n) in
      let vector_tm = mk_vector_list (if n = 0 then [x_var] else vars) in
      let fun_tm = mk_abs (x_var, subst (zip components vars) expr_tm) in
        fun_tm, vector_tm;;


(* Given an inequality `P x y`, variable names and the corresponding bounds,
   yields `(x0 <= x /\ x <= x1) /\ (y0 <= y /\ y <= y1) ==> P x y`.
   dom_tm is a pair of list terms: lower bounds and upper bounds.
   All variables are real variables with the given names *)
let mk_ineq ineq_tm names dom_tm =
  let lo_tm, hi_tm = dom_tm in
  let lo_list = dest_list lo_tm in
  let hi_list = dest_list hi_tm in
  let vars = map (fun name -> mk_var (name, real_ty)) names in
  let mk_le tm1 tm2 = mk_binop le_op_real tm1 tm2 in
  let lo_ineqs = map2 mk_le lo_list vars in
  let hi_ineqs = map2 mk_le vars hi_list in
  (* One conjunction of two bounds for each variable *)
  let var_conds = map2 (curry mk_conj) lo_ineqs hi_ineqs in
    mk_imp (end_itlist (curry mk_conj) var_conds, ineq_tm);;


(* Reverts the effect of mk_ineq function *)
(* Returns the inequality, the list of names of bounded variables, *)
(* and the pair of list terms with lower and upper bounds. *)
(* A term without free variables is returned unchanged with empty lists. *)
(* Each condition must have the form `lo <= x` or `x <= hi` where x is *)
(* a variable and the bound has no free variables. *)
(* Fails if a condition has a wrong form or if some bound is missing *)
let dest_ineq ineq_tm =
  if frees ineq_tm = [] then
    ineq_tm, [], (real_empty_list, real_empty_list)
  else
    (* Flattens nested implications and conjunctions *)
    let flat_tm = (rand o concl o PURE_REWRITE_CONV[IMP_IMP; GSYM CONJ_ASSOC]) ineq_tm in
    let cond, ineq = dest_imp flat_tm in
    let conds = striplist dest_conj cond in
    (* Association list: a variable name and a reference to its bounds. *)
    (* The variable x_var_real marks a bound which is not defined yet *)
    let table = ref [] in
    let find_cell name =
      try assoc name !table
      with Failure _ ->
        let cell = ref (x_var_real, x_var_real) in
          table := ((name, cell) :: !table); cell in
    let decode_ineq tm =
      let lhs, rhs = dest_binop le_op_real tm in
      (* A closed left hand side means that the condition is a lower bound *)
      let lo_flag = (frees lhs = []) in
      let var_tm = if lo_flag then rhs else lhs in
      let cell = find_cell (fst (dest_var var_tm)) in
        cell := if lo_flag then (lhs, snd !cell) else (fst !cell, rhs) in
    let () = do_list (fun tm ->
                        (try decode_ineq tm with Failure _ ->
                           failwith ("Bad variable bound inequality: "^string_of_term tm))) conds in
    let names, cells = unzip !table in
    let lo, hi = unzip (map (fun cell -> !cell) cells) in
    (* A bound with free variables is either undefined or invalid *)
    let test_bounds bounds bound_name =
      let _ = map2 (fun tm name -> if frees tm <> [] then
                      failwith (bound_name^" bound is not defined for "^name) else ())
        bounds names in () in
    let _ = test_bounds hi "Upper"; test_bounds lo "Lower" in
      ineq, names, (mk_real_list lo, mk_real_list hi);;

(*********************************)

(* Normalizes a verification result *)
(* Arguments: *)
(*   norm_flag: if true, then the domain conditions are rewritten as *)
(*     componentwise inequalities and the result is generalized with GEN_ALL; *)
(*     if false, then the theorem with the discharged hypotheses is returned *)
(*   v1: the vector term which is substituted for the bound variable *)
(*   eq_th1: the theorem produced by mk_standard_ineq; it is used right-to-left *)
(*     to restore the original form of the inequality *)
(*   domain_sub_th: the SUBSET theorem produced by mk_float_domain *)
(*   pass_thm: the result of the formal verification (an m_cell_pass theorem) *)
let normalize_result norm_flag v1 eq_th1 domain_sub_th pass_thm =
  let th0 = REWRITE_RULE[m_cell_pass] pass_thm in
  let n = (get_dim o fst o dest_forall o concl) th0 in
  let th1 = SPEC v1 th0 in
  let comp_thms = end_itlist CONJ (Array.to_list comp_thms_array.(n)) in
  let th2 = REWRITE_RULE[comp_thms] th1 in
  let th3 = (UNDISCH_ALL o REWRITE_RULE[GSYM eq_th1]) th2 in
  (* The domain condition of th3 follows from the membership in the original domain *)
  let dom_th = (UNDISCH_ALL o SPEC v1 o REWRITE_RULE[SUBSET]) domain_sub_th in
  let th4 = (DISCH_ALL o MY_PROVE_HYP dom_th) th3 in
    if norm_flag then
      (* The final rewriting is performed only when its result is returned *)
      let th5 = REWRITE_RULE[IN_INTERVAL; dimindex_array.(n); gen_in_interval n; comp_thms] th4 in
        GEN_ALL th5
    else
      th4;;


(* Verifies the given inequality *)
(* Returns the final theorem and verification statistics *)
(* Arguments: *)
(*   params0: verification parameters (a local copy may be adjusted by *)
(*     mk_verification_functions) *)
(*   norm_flag: passed to normalize_result *)
(*   pp: passed to all evaluation and verification functions *)
(*     (NOTE(review): looks like the precision parameter -- confirm) *)
(*   ineq_tm: the inequality without domain conditions *)
(*   var_names, (lo_tm, hi_tm): names of variables and list terms with *)
(*     the corresponding lower and upper bounds *)
(*   rewrite_thms: additional theorems for mk_standard_ineq *)
(* Fails if bounds are not given for a free variable of the inequality *)
let verify_ineq0 params0 norm_flag pp ineq_tm var_names (lo_tm, hi_tm) rewrite_thms =
  (* Prints a progress message if the information level is at least 1 *)
  let info msg = if !info_print_level >= 1 then (report0 msg; ()) in
  let total_start = Unix.gettimeofday() in
  let eq_th1 = mk_standard_ineq rewrite_thms ineq_tm in
  (* The left hand side of the standard form: the inequality is ineq_tm1 < &0 *)
  let ineq_tm1 = (lhand o rand o concl) eq_th1 in
    if frees ineq_tm1 = [] then
      (* A constant inequality: one interval evaluation is enough *)
      let i_fun = build_interval_fun ineq_tm1 in
      let th0 = eval_interval_fun pp i_fun [] [] in
      let th1 = float_interval_lt0 th0 in
      let total = Unix.gettimeofday() -. total_start in
        REWRITE_RULE[GSYM eq_th1] th1,
        {total_time = total; formal_verification_time = total; certificate = Verifier.dummy_stats}
    else
      let fun_tm, v1 = expr_to_vector_fun ineq_tm1 in
      let vars = map (fst o dest_var) (dest_vector v1) in
      let lo_list = dest_list lo_tm and
          hi_list = dest_list hi_tm in
      let bounds0 = zip var_names (zip lo_list hi_list) in
      (* Bounds are reordered to match the order of components of v1 *)
      let find_bounds name =
        try assoc name bounds0
        with Failure _ ->
          failwith ("verify_ineq0: bounds are not defined for the variable "^name) in
      let bounds = itlist (fun name list -> find_bounds name :: list) vars [] in
      let xx, zz = unzip bounds in
      let xx, zz = mk_real_list xx, mk_real_list zz in

      let domain_sub_th, (xx1, zz1) = mk_float_domain pp (xx, zz) in
      let n = (get_dim o fst o dest_abs) fun_tm in
      let xx2 = Informal_taylor.convert_to_float_list pp true xx and
          zz2 = Informal_taylor.convert_to_float_list pp false zz in
      let xx_float = map float_of_float_tm (dest_list xx1) and
          zz_float = map float_of_float_tm (dest_list zz1) in

      (* mk_verification_functions may switch off some flags in params *)
      let params = ref params0 in
      let eval_fs, tf, ti = mk_verification_functions params pp fun_tm in

      let () = info "Constructing a solution certificate... " in
      let certificate = Verifier.run_test tf xx_float zz_float false 0.0
        !params.allow_derivatives !params.convex_flag !params.mono_pass_flag
        !params.raw_intervals_flag !params.eps in
      let stats = Verifier.result_stats certificate in
      let () = info " done\n" in
      let () = if !info_print_level >= 1 then (Verifier.report_stats stats; ()) in

      let c1 = Verifier.transform_result xx_float zz_float certificate in
      (* Runs the formal verification step and measures its time *)
      let run_formal verify =
        let () = info "Formal verification... " in
        let start = Unix.gettimeofday() in
        let result = verify () in
        let finish = Unix.gettimeofday() in
        let () = info " done\n" in
          start, finish, result in
      let start, finish, result =
        if !params.adaptive_precision then
          (* The informal verification transforms the certificate first *)
          let () = info "Informal verification... " in
          let c1p = Informal_verifier.m_verify_list pp 1 pp ti c1 xx2 zz2 in
          let () = info " done\n" in
            run_formal (fun () -> m_p_verify_list n pp eval_fs c1p xx1 zz1)
        else
          run_formal (fun () -> m_verify_list n pp eval_fs c1 xx1 zz1) in
        normalize_result norm_flag v1 eq_th1 domain_sub_th result,
        {total_time = finish -. total_start; formal_verification_time = finish -. start; certificate = stats};;
	

(* A simple verification function which accepts
   a list of rewrite theorems which are applied to the inequality
   before verification.
   ineq_tm must have the form accepted by dest_ineq: the domain conditions
   imply the inequality. The result is always normalized *)
let verify_ineq_and_rewrite rewrite_thms params pp ineq_tm =
  let body_tm, var_names, domain = dest_ineq ineq_tm in
  let norm_flag = true in
    verify_ineq0 params norm_flag pp body_tm var_names domain rewrite_thms;;


(* The simplest verification function: *)
(* verify_ineq_and_rewrite without additional rewrite theorems *)
let verify_ineq params pp ineq_tm =
  verify_ineq_and_rewrite [] params pp ineq_tm;;


end;;