(* =========================================================== *)
(* Auxiliary formal verification functions                     *)
(* Author: Alexey Solovyev                                     *)
(* Date: 2012-10-27                                            *)
(* =========================================================== *)

needs "arith/eval_interval.hl";;
needs "arith/more_float.hl";;
needs "taylor/m_taylor.hl";;
needs "verifier/interval_m/taylor.ml";;
needs "informal/informal_m_verifier.hl";;
needs "verifier_options.hl";;
needs "misc/vars.hl";;

module M_verifier_build = struct

open More_float;;
open Eval_interval;;
open M_taylor;;

open Interval_types;;
open Interval;;
open Line_interval;;
open Taylor;;

open Misc_vars;;
open Verifier_options;;

(* Record of formal (theorem-producing) evaluation routines for a single *)
(* function; mk_verification_functions_poly below builds one of these.   *)
(* In the signatures, pp is presumably a precision parameter and lo, hi  *)
(* are presumably the endpoint terms of the domain (confirm at callers). *)
type verification_funs =
{
  (* Taylor approximation theorem for a domain theorem:  *)
  (* p_lin -> p_second -> domain_th -> taylor_th *)
  taylor : int -> int -> thm -> thm;
  (* Interval bound of the function itself: *)
  (* pp -> lo -> hi -> interval_th *)
  f : int -> term -> term -> thm;
  (* Interval bound of the derivative selected by index i: *)
  (* i -> pp -> lo -> hi -> interval_th *)
  df : int -> int -> term -> term -> thm;
  (* Interval bound of the second derivative selected by indices i, j: *)
  (* i -> j -> pp -> lo -> hi -> interval_th *)
  ddf : int -> int -> int -> term -> term -> thm;
  (* Second-order differentiability theorem on a domain: *)
  (* lo -> hi -> diff2_th *)
  diff2_f : term -> term -> thm;
};;

(****************************)
(* Interval polynomial functions for the native OCaml arithmetic *)

(* Syntax tree of a polynomial expression over intervals. *)
(* Values of this type are evaluated by eval_int_poly_fun. *)
type int_poly_fun =
  (* Variable, identified by its 1-based position in the argument list *)
  | F_int_var of int
  (* Constant interval *)
  | F_int_const of interval
  (* (n, f): the expression f combined with the integer n through ipow *)
  | F_int_pow of int * int_poly_fun
  (* Negation (ineg) *)
  | F_int_neg of int_poly_fun
  (* Sum (iadd) *)
  | F_int_add of int_poly_fun * int_poly_fun
  (* Difference (isub) *)
  | F_int_sub of int_poly_fun * int_poly_fun
  (* Product (imul) *)
  | F_int_mul of int_poly_fun * int_poly_fun;;


(* Arith_misc.gen_pow specialised to imul and Interval.one.              *)
(* It is applied as ipow n x in eval_int_poly_fun; presumably the result *)
(* is the n-th interval power of x (confirm against gen_pow).            *)
let ipow = Arith_misc.gen_pow imul Interval.one;;    


(* Turns an int_poly_fun tree into an evaluator over a list of intervals. *)
(*   i_fun : the expression tree                                          *)
(*   x     : list of argument intervals; variable number k reads the      *)
(*           element of x at position k - 1 (List.nth)                    *)
(* Returns the interval obtained by folding the tree with ineg, ipow,     *)
(* iadd, isub and imul.                                                   *)
let eval_int_poly_fun i_fun x =
  let rec go = function
    | F_int_const c -> c
    | F_int_var k -> List.nth x (k - 1)
    | F_int_neg g -> ineg (go g)
    | F_int_add (g, h) -> iadd (go g) (go h)
    | F_int_sub (g, h) -> isub (go g) (go h)
    | F_int_mul (g, h) -> imul (go g) (go h)
    | F_int_pow (n, g) -> ipow n (go g) in
  go i_fun;;

(****************************)
(* Automatic conversion of formal interval polynomials into functions (polynomials) *)
(* TODO: take Int_ref into account *)

(* Converts a formal interval function description into an int_poly_fun. *)
(*   Int_var tm    : the index is the small numeral in the operand of tm *)
(*                   or, if that fails, the number obtained by dropping  *)
(*                   the first character of the variable name            *)
(*   Int_const th  : the pair of float terms in the operand of the       *)
(*                   conclusion of th becomes a constant interval        *)
(*   Int_pow, Int_unary, Int_binary : translated recursively             *)
(* Fails on any operator other than real negation, addition, subtraction *)
(* and multiplication, and on every other kind of description.           *)
let rec build_poly_fun i_fun =
  let unsupported op = failwith ("Unsupported operator: "^string_of_term op) in
  match i_fun with
    | Int_const th ->
        let lo_tm, hi_tm = (dest_pair o rand o concl) th in
          F_int_const (mk_interval (float_of_float_tm lo_tm, float_of_float_tm hi_tm))
    | Int_var tm ->
        let index =
          try dest_small_numeral (rand tm)
          with Failure _ ->
            (* Fall back to the digits following the first character of the name *)
            let name = fst (dest_var tm) in
              int_of_string (String.sub name 1 (String.length name - 1)) in
          F_int_var index
    | Int_pow (n, base) -> F_int_pow (n, build_poly_fun base)
    | Int_unary (op, arg) ->
        (* The argument is translated before the operator is inspected *)
        let arg_poly = build_poly_fun arg in
          if op = neg_op_real then F_int_neg arg_poly else unsupported op
    | Int_binary (op, lhs, rhs) ->
        let lhs_poly, rhs_poly = build_poly_fun lhs, build_poly_fun rhs in
          if op = add_op_real then F_int_add (lhs_poly, rhs_poly)
          else if op = sub_op_real then F_int_sub (lhs_poly, rhs_poly)
          else if op = mul_op_real then F_int_mul (lhs_poly, rhs_poly)
          else unsupported op
    | _ -> failwith "Unsupported function";;


(* Builds native interval evaluators from the theorem lin_th.              *)
(* The conjuncts in the operand of the conclusion of lin_th are turned     *)
(* into int_poly_fun trees (constants evaluated with parameter pp) and the *)
(* list is padded with zero constants up to 9 entries: the head entry      *)
(* followed by 8 further slots (presumably the function and its partial    *)
(* derivatives -- confirm against the producer of lin_th).                 *)
(* Returns a pair of functions:                                            *)
(*   - i x z : entry i (0 is the head entry, i > 0 the i-th further slot)  *)
(*             evaluated on the intervals built pairwise from x and z;     *)
(*   - x     : mk_line of the head entry and all further slots evaluated   *)
(*             at the degenerate intervals (t, t) for each t in x.         *)
let build_polyL pp lin_th =
  let conjuncts = (striplist dest_conj o rand o concl) lin_th in
  let funs = map (fst o dest_interval_arith) conjuncts in
  let i_funs = map (eval_constants pp o build_interval_fun) funs in
  let padding = replicate (F_int_const zero) (8 - length funs + 1) in
  let eval_fs = map eval_int_poly_fun (map build_poly_fun i_funs @ padding) in
  let f, df = hd eval_fs, tl eval_fs in
  let eval_component i x z =
    let vars = map2 (curry mk_interval) x z in
      match i with
        | 0 -> f vars
        | _ -> (List.nth df (i - 1)) vars in
  let eval_line x =
    let vars = map (fun t -> mk_interval (t,t)) x in
      mk_line (f vars, map (fun d -> d vars) df) in
    eval_component, eval_line;;

(* Same as build_polyL, but the theorem is first generated from poly_tm *)
(* by gen_lin_approx_poly_thm0.                                         *)
let build_polyL0 pp poly_tm =
  build_polyL pp (gen_lin_approx_poly_thm0 poly_tm);;

(* Builds a native interval evaluator for the table of bounds stored in    *)
(* second_th.                                                              *)
(* The dimension n is read from the bound variable of the polynomial term  *)
(* found in the conclusion of second_th.  The conjuncts under the          *)
(* quantifier are converted to evaluators; they are addressed as a         *)
(* lower-triangular table stored row by row (row i holds columns 1..i).    *)
(* Returns a function of x and z which forms intervals pairwise from the   *)
(* two lists and yields an 8 x 8 list of lists: entry (i, j) with          *)
(* i, j <= n is the triangular entry for the larger and smaller of the two *)
(* indices, and every remaining position is the interval zero.             *)
let build_polyDD pp second_th =
  let poly_tm = (lhand o rator o lhand o concl) second_th in
  let n = get_dim (fst (dest_abs poly_tm)) in
  let ns = 1--n in
  let conjuncts = (striplist dest_conj o rand o snd o dest_forall o rand o concl) second_th in
  let i_funs = map (eval_constants pp o build_interval_fun o fst o dest_interval_arith) conjuncts in
  let eval_fs = map eval_int_poly_fun (map build_poly_fun i_funs) in
  (* Padding that completes each of the n rows and adds the 8 - n zero rows *)
  let row_pad = replicate zero (8 - n) in
  let zero_rows = replicate (replicate zero 8) (8 - n) in
  let get_el dd i j =
    (* Order the indices so that the row index is the larger one *)
    let row, col = if i >= j then i, j else j, i in
      List.nth dd ((row - 1) * row / 2 + (col - 1)) in
    fun x z ->
      let ints = map2 (curry mk_interval) x z in
      let vals = map (fun f -> f ints) eval_fs in
      let make_row i = map (fun j -> get_el vals i j) ns @ row_pad in
        map make_row ns @ zero_rows;;


(* Same as build_polyDD, but the theorem is first generated from poly_tm *)
(* by gen_second_bounded_poly_thm0.                                      *)
let build_polyDD0 pp poly_tm =
  build_polyDD pp (gen_second_bounded_poly_thm0 poly_tm);;


(******)

(* Combines the evaluators from build_polyL (for lin_th) and build_polyDD *)
(* (for second_th) into a single value: the three functions are passed to *)
(* make_primitiveA and the result is wrapped in the Prim_a constructor.   *)
let build_poly_taylor pp lin_th second_th =
  let first_order = build_polyL pp lin_th and
      second_order = build_polyDD pp second_th in
  let component, line = first_order in
    Prim_a (make_primitiveA (component, line, second_order));;

(* Convenience wrapper around build_poly_taylor: both theorems are derived *)
(* from poly_tm with gen_lin_approx_poly_thm0 and                          *)
(* gen_second_bounded_poly_thm0.                                           *)
let build_poly_taylor0 pp poly_tm =
  build_poly_taylor pp (gen_lin_approx_poly_thm0 poly_tm) (gen_second_bounded_poly_thm0 poly_tm);;


(**********************************)  
(* mk_verification_functions *)

(* Builds all verification data for the polynomial poly_tm (a lambda       *)
(* abstraction) using the parameter pp0.                                   *)
(* Returns a triple:                                                       *)
(*   - a verification_funs record of formal evaluation functions;          *)
(*   - the native Taylor function produced by build_poly_taylor;           *)
(*   - the result of Informal_verifier.mk_verification_functions_poly.     *)
(* Progress messages are sent to report0 unless info_print_level is 0.     *)
(* In the returned record, df i selects entry i of the list built from the *)
(* partial derivative theorems, and ddf i j reads column i of row j of a   *)
(* lower-triangular table (row j has j entries), so ddf expects i <= j.    *)
let mk_verification_functions_poly pp0 poly_tm =
  (* Only the bound variable is needed; it determines the dimension *)
  let x_tm, _ = dest_abs poly_tm in
  let n = get_dim x_tm in

  (* Emits a progress message unless printing is switched off *)
  let progress msg =
    if not (!info_print_level = 0) then ignore (report0 msg) in

  let () = progress (sprintf "Computing partial derivatives (%d)..." n) in
  (* partials: one theorem per variable index 1..n *)
  let partials = map (fun i ->
                        progress (sprintf " %d" i);
                        gen_partial_poly i poly_tm) (1--n) in
  (* Differentiates the right-hand side of eq_th with respect to variable i *)
  (* and chains the result with eq_th                                       *)
  let get_partial i eq_th =
    let partial_i = gen_partial_poly i (rand (concl eq_th)) in
    let pi = (rator o lhand o concl) partial_i in
      REWRITE_RULE[GSYM partial2] (TRANS (AP_TERM pi eq_th) partial_i) in
  (* partials2: lower-triangular table, row j holds the entries for i = 1..j *)
  let partials2 = map (fun j ->
                         let th = List.nth partials (j - 1) in
                           map (fun i ->
                                  progress (sprintf " %d,%d" j i);
                                  get_partial i th) (1--j)) (1--n) in

  let () = progress " done\n" in

  let diff_th = gen_diff_poly poly_tm in
  let lin_th = gen_lin_approx_poly_thm poly_tm diff_th partials in
  let diff2_th = gen_diff2c_domain_poly poly_tm in
  let second_th = gen_second_bounded_poly_thm poly_tm partials2 in

  (* Rewrites the numeral i on the left-hand side of th with the equation *)
  (* produced by NUMERAL_TO_NUM_CONV and NUM_THM                          *)
  let replace_numeral i th =
    let num_eq = (REWRITE_RULE[Arith_hash.NUM_THM] o Arith_nat.NUMERAL_TO_NUM_CONV)
      (mk_small_numeral i) in
      GEN_REWRITE_RULE (LAND_CONV o RATOR_CONV o DEPTH_CONV) [num_eq] th in

  let eval0 = mk_eval_function pp0 poly_tm in
  let eval1 = map (fun i ->
                     let d_th = List.nth partials (i - 1) in
                     let eq_th = replace_numeral i d_th in
                       mk_eval_function_eq pp0 eq_th) (1--n) in

  (* Same triangular layout as partials2: row i holds columns 1..i *)
  let eval2 = map (fun i ->
                     map (fun j ->
                            let d2_th = List.nth (List.nth partials2 (i - 1)) (j - 1) in
                            let eq_th' = replace_numeral i d2_th in
                            let eq_th = replace_numeral j eq_th' in
                              mk_eval_function_eq pp0 eq_th) (1--i)) (1--n) in

  let diff2_f = eval_diff2_poly diff2_th in
  let eval_f = eval_m_taylor pp0 diff2_th lin_th second_th in
  let taylor_f = build_poly_taylor pp0 lin_th second_th in
    {taylor = eval_f;
     f = eval0;
     df = (fun i -> List.nth eval1 (i - 1));
     ddf = (fun i j -> List.nth (List.nth eval2 (j - 1)) (i - 1));
     diff2_f = diff2_f;
    }, taylor_f, Informal_verifier.mk_verification_functions_poly pp0 poly_tm partials partials2;;


	
  
end;;