needs "../formal_lp/formal_interval/interval_1d/recurse.hl";;
needs "../formal_lp/formal_interval/second_approx.hl";;
needs "../formal_lp/formal_interval/verifier.hl";;


open Taylor;;
open Univariate;;
open Interval;;
open Recurse;;

(* The only variable accepted by mk_taylor_functions. *)
let x_var_real = `x:real`;;

(* Real arithmetic operators that the term translators in this file
   compare against when dispatching on the head of an application. *)
let add_op_real = `(+):real->real->real`;;
let sub_op_real = `(-):real->real->real`;;
let mul_op_real = `( * ):real->real->real`;;
let div_op_real = `(/):real->real->real`;;
let neg_op_real = `(--):real->real`;;

(* Division expressed through the other Taylor primitives: th2 is wrapped
   as a constant evaluator, composed with eval_taylor_inv on the domain
   extracted from th1, and the result is multiplied by th1.
   pp is passed through to every primitive unchanged. *)
let taylor_interval_div_alt pp th1 th2 =
  let domain_th, _, _ = dest_taylor_interval_th th1 in
  let inv_of_th2 = taylor_interval_compose pp eval_taylor_inv (fun _ _ -> th2) in
  let inv_th = inv_of_th2 domain_th in
    taylor_interval_mul pp th1 inv_th;;


(* Lifts a binary operation on Taylor interval theorems to evaluators.
   arg1 and arg2 are evaluators taking pp and a domain theorem; the
   resulting evaluator runs both on the same inputs and combines the
   two results with taylor_binop. *)
let mk_formal_binop taylor_binop arg1 arg2 pp domain_th =
  let left_th = arg1 pp domain_th and
      right_th = arg2 pp domain_th in
    taylor_binop pp left_th right_th;;


(* Recursively evaluates a ground real arithmetic term to a float interval
   theorem, with pp passed to the underlying float_interval_* procedures.

   Accepted shapes:
     - op l r  where op is one of plus_op_real, mul_op_real, div_op_real,
       minus_op_real (these names are not defined in this file; presumably
       they come from one of the opened or loaded modules);
     - neg_op_real r;
     - any other application f n, where n must be a numeral accepted by
       dest_numeral.  The head f is not inspected; presumably it is the
       real injection & -- TODO confirm.

   Fails (Failure) if expr_tm is not an application, if the binary
   operator is not recognized, or if a leaf is not a numeral. *)
let rec eval_interval pp expr_tm =
  let ltm, r_tm = dest_comb expr_tm in
    if is_comb ltm then
      let op, l_tm = dest_comb ltm in
      let l_th = eval_interval pp l_tm in
      let r_th = eval_interval pp r_tm in
        if op = plus_op_real then
          float_interval_add pp l_th r_th
        else if op = mul_op_real then
          float_interval_mul pp l_th r_th
        else if op = div_op_real then
          float_interval_div pp l_th r_th
        else if op = minus_op_real then
          float_interval_sub pp l_th r_th
        else
          (* dest_const itself fails when op is not a constant, which would
             hide the intended message; fall back to printing the term. *)
          let op_name =
            try fst (dest_const op) with Failure _ -> string_of_term op in
            failwith ("Unknown operation: " ^ op_name)
    else
      if ltm = neg_op_real then
        let r_th = eval_interval pp r_tm in
          float_interval_neg r_th
      else
        mk_float_interval_num (dest_numeral r_tm);;


(* Returns an equational theorem whose left-hand side is tm.  The
   right-hand side is obtained by rewriting with DECIMAL and then
   rewriting that result with the theorem produced by
   REAL_RAT_REDUCE_CONV. *)
let rat_term_of_term tm =
  let decimal_th = REWRITE_CONV[DECIMAL] tm in
  let unfolded_tm = rand (concl decimal_th) in
  let reduce_th = REAL_RAT_REDUCE_CONV unfolded_tm in
    GEN_REWRITE_RULE RAND_CONV [reduce_th] decimal_th;;


(* Converts a constant real term to an OCaml float: the term is first
   normalized with rat_term_of_term, the right-hand side is read with
   rat_of_term and the result is converted with float_of_num. *)
let float_of_term tm =
  let rat_tm = rand (concl (rat_term_of_term tm)) in
    float_of_num (rat_of_term rat_tm);;


(* Builds a float interval theorem for the constant term tm.  The interval
   is computed by eval_interval for the normalized form given by
   rat_term_of_term, then the normalizing equation is rewritten backwards
   so that the theorem mentions tm in its original form. *)
let interval_of_term pp tm =
  let eq_th = rat_term_of_term tm in
  let rat_tm = rand (concl eq_th) in
  let rat_int_th = eval_interval pp rat_tm in
    REWRITE_RULE[SYM eq_th] rat_int_th;;


(* Translates a real-valued HOL term in the single variable x into a pair
     (informal, formal)
   where informal is a function description built from x1, Uni, Scale,
   Composite, Plus, Minus, Product and Uni_compose (as consumed by
   recursive_verifier elsewhere in this file), and formal is an evaluator
   taking pp and a domain theorem and returning a Taylor interval theorem.

   Cases, tried in this order:
     - the variable x (any other variable is rejected);
     - the constants inv, sqrt, atn (as primitive univariate functions);
     - a rational constant (anything that rat_of_term accepts after
       normalization with rat_term_of_term);
     - negation of a subexpression;
     - a binary operation +, -, *, / (division is encoded as a product
       with the composed inverse);
     - application of a translated function to a translated argument.

   NOTE: no formal evaluator for negation of a non-constant expression is
   available in this file.  The informal part is still produced, but the
   formal evaluator fails when it is invoked instead of silently returning
   the evaluator of x, which would describe a different function.

   Fails (Failure) on unsupported terms. *)
let rec mk_taylor_functions pp expr_tm =
  (* Variable *)
  if is_var expr_tm then
    if expr_tm = x_var_real then
      (x1, (fun pp -> eval_taylor_x))
    else
      failwith "mk_taylor_functions: only x variable is allowed"
  else
    (* Constant *)
    if is_const expr_tm then
      let name, _ = dest_const expr_tm in
        if name = "inv" then
          (Uni (mk_primitiveU uinv), eval_taylor_inv)
        else if name = "sqrt" then
          (Uni (mk_primitiveU usqrt), eval_taylor_sqrt)
        else if name = "atn" then
          (Uni (mk_primitiveU uatan), eval_taylor_atn)
        else
          failwith ("mk_taylor_functions: unknown constant: "^name)
    else
      (* Number *)
      try
        let rat_th = rat_term_of_term expr_tm in
        let rat_tm = rand (concl rat_th) in
        let n = (float_of_num o rat_of_term) rat_tm in
        let int_th =
          let th0 = eval_interval pp rat_tm in
            REWRITE_RULE[SYM rat_th] th0 in
          (Scale (unit, mk_interval (n, n)),
           (fun pp -> eval_taylor_const_interval int_th))
      with Failure _ ->
        (* Combination *)
        let lhs, r_tm = dest_comb expr_tm in
        let r, r2 = mk_taylor_functions pp r_tm in
        (* Binary operations: lhs is expected to have the form op l_tm *)
        let mk_binary () =
          let op, l_tm = dest_comb lhs in
          let l, l2 = mk_taylor_functions pp l_tm in
            if op = add_op_real then
              (Plus (l,r), mk_formal_binop taylor_interval_add l2 r2)
            else if op = sub_op_real then
              (Minus (l,r), mk_formal_binop taylor_interval_sub l2 r2)
            else if op = mul_op_real then
              (Product (l,r), mk_formal_binop taylor_interval_mul l2 r2)
            else if op = div_op_real then
              (Product (l, Uni_compose (uinv,r)),
               mk_formal_binop taylor_interval_div_alt l2 r2)
            else
              failwith ("mk_taylor_functions: unknown binary operation "^string_of_term op) in
        (* A partially applied known operator can never be translated as a
           function on its own, so the composite attempt below would always
           fail for it, but only after translating the left operand, which
           mk_binary then translates again.  Recognizing the operator up
           front avoids that repeated (exponential for left-nested terms)
           work without changing the result. *)
        let has_known_binop =
          if is_comb lhs then
            mem (rator lhs) [add_op_real; sub_op_real; mul_op_real; div_op_real]
          else
            false in
          (* Unary operations *)
          if lhs = neg_op_real then
            (Scale (r, ineg one),
             (fun _ _ ->
                failwith "mk_taylor_functions: formal negation is not implemented"))
          else if has_known_binop then
            mk_binary ()
          else
            (* Composite *)
            try
              let l, l2 = mk_taylor_functions pp lhs in
                if r_tm = x_var_real then
                  (l, l2)
                else
                  (Composite (l,r),
                   (fun pp domain_th -> taylor_interval_compose pp l2 r2 domain_th))
            with Failure _ ->
              if is_comb lhs then
                mk_binary ()
              else
                failwith ("mk_taylor_functions: error: "^string_of_term lhs);;

(**********************************)

(* Counts the Result_pass leaves of a verification result tree.
   Fails if the tree contains Result_false anywhere. *)
let rec result_size = function
  | Result_pass _ -> 1
  | Result_glue (left, right) -> result_size left + result_size right
  | Result_false -> failwith "False result detected";;



(* Test expression: a quotient of two polynomials in x (written with
   let-bound coefficients and powers of x), minus atn x, minus #0.001. *)
let atan1 = `let P00, P01, P02 = #4.26665376811246382, #3.291955407318148, #0.281939359003812 in
  let Q00, Q01 = #4.26665376811354027, #4.714173328637605 in
  let x1 = x * x in
  let x2 = x1 * x1 in
    x * (x1 * P01 + x2 * P02 + P00) / (x1 * Q01 + x2 + Q00) - atn x - #0.001`;;
(* The same expression after repeated application of let_CONV, i.e. the
   right-hand side of the resulting equation. *)
let atan1_tm = (rand o concl o REPEATC let_CONV) atan1;;

(* tan(pi/16) = 0.198912367379658 *)

(* Informal and formal versions of the test expression, built with pp = 15 *)
let atn1_f, atn1_formal = mk_taylor_functions 15 atan1_tm;;

(* Informal verification of the test expression on the range 0.0 .. 0.2 *)
let result =
  recursive_verifier
    (0, 0.0, 0.2, atn1_f,
     { allow_sharp = false;
       iteration_count = 0;
       iteration_limit = 0;
       recursion_depth = 20 });;

(* Number of passing cells in the informal result *)
result_size result;;


(* Replays an informal verification result formally.
   The result tree is traversed together with the bounds lo_tm, hi_tm:
     - for Result_pass, formal_f is evaluated on the domain theorem built
       by mk_center_domain and handed to taylor_cell_pass; progress is
       reported as passed/total;
     - for Result_glue, the second component returned by dest_cell_domain
       for the current domain is used as the split point, both halves are
       proved (left one first in the binding) and joined with
       cell_pass_glue;
     - Result_false makes the function fail. *)
let prove_result pp formal_f result lo_tm hi_tm =
  let total = result_size result in
  let passed = ref 0 in
  let rec prove_cell res lo hi =
    let domain_th = mk_center_domain pp lo hi in
      match res with
        | Result_glue (left, right) ->
            let _, mid_tm, _, _ = dest_cell_domain (concl domain_th) in
            let left_th = prove_cell left lo mid_tm and
                right_th = prove_cell right mid_tm hi in
              cell_pass_glue left_th right_th
        | Result_pass _ ->
            let _ = passed := !passed + 1 in
            let pass_th = taylor_cell_pass pp (formal_f pp domain_th) in
            let _ = report (sprintf "Result_pass: %d/%d" !passed total) in
              pass_th
        | Result_false -> failwith "False result" in
    prove_cell result lo_tm hi_tm;;



(* Float terms for the bounds of the formal run of the test expression *)
let lo_tm, hi_tm =
  let lower = mk_float 0 min_exp and
      upper = mk_float 2 (min_exp - 1) in
    lower, upper;;

(* Formal replay of the informal result computed above *)
prove_result 15 atn1_formal result lo_tm hi_tm;;


(* From cell_pass f (lo, hi) and interval_arith facts for x and z whose
   lower and upper ends are lo and hi respectively, concludes f t < &0 for
   every t in real_interval [x, z].  The proof unfolds cell_pass,
   interval_arith and IN_REAL_INTERVAL, applies the cell_pass hypothesis
   and closes the remaining goal with real arithmetic. *)
let CELL_PASS_IMP = 
prove (`cell_pass f (lo, hi) ==> interval_arith x (lo, w1) ==> interval_arith z (w2, hi) ==> !t. t IN real_interval [x, z] ==> f t < &0`,
REWRITE_TAC[cell_pass; interval_arith; IN_REAL_INTERVAL] THEN REPEAT STRIP_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REAL_ARITH_TAC);;
(* Proves a one-dimensional inequality for expr_tm between the constant
   terms lo_tm and hi_tm.
   Steps:
     1. translate expr_tm with mk_taylor_functions;
     2. obtain float values and float interval theorems for both bounds;
     3. run recursive_verifier informally on the float bounds
        (recursion_depth = 50);
     4. replay the result formally with prove_result, from the first
        component of the lower bound interval to the second component of
        the upper bound interval;
     5. combine everything through CELL_PASS_IMP. *)
let prove_ineq1d pp expr_tm lo_tm hi_tm =
  let f, formal_f = mk_taylor_functions pp expr_tm in
  let lo, hi = float_of_term lo_tm, float_of_term hi_tm in
  let lo_int, hi_int = interval_of_term pp lo_tm, interval_of_term pp hi_tm in
  let interval_pair = dest_pair o rand o concl in
  let lo1_tm = fst (interval_pair lo_int) and
      hi1_tm = snd (interval_pair hi_int) in
  let _ = report (sprintf "Starting verification from %f to %f" lo hi) in
  let result =
    let opt = {
      allow_sharp = false;
      iteration_count = 0;
      iteration_limit = 0;
      recursion_depth = 50
    } in
      recursive_verifier (0, lo, hi, f, opt) in
  let cell_th = prove_result pp formal_f result lo1_tm hi1_tm in
  let with_cell = MATCH_MP CELL_PASS_IMP cell_th in
  let with_lo = MATCH_MP with_cell lo_int in
  let with_hi = MATCH_MP with_lo hi_int in
    REWRITE_RULE[] with_hi;;
(* Scratch evaluations, one statement per line *)
prove_ineq1d 10 `atn x - x` `#0.001` `#30`;;
result_size result;;

let domain_th = mk_center_domain 10 (mk_float 175 (min_exp - 3)) (mk_float 2 (min_exp - 1));;
let th1 = atn1_formal 10 domain_th;;
test 1 (atn1_formal 10) domain_th;;
evalf atn1_f 0.0 0.025;;
eval_taylor_f_bounds 10 th1;;

(**********************************)

let pp = 10;;
let domain_th = mk_center_domain pp (mk_float 1 (min_exp - 1)) (mk_float 8 (min_exp - 1));;
let f, f2 = mk_taylor_functions 10 `atn x / x - &1 / &3`;;
evalf f 0.1 0.8;;
f2 pp domain_th;;

let opt = {
  allow_sharp = false;
  iteration_count = 0;
  iteration_limit = 0;
  recursion_depth = 100
} in
  recursive_verifier (0, 0.1, 0.8, f, opt);;

let int1 = f2 pp (mk_center_domain pp (mk_float 1 49) (mk_float 14375 45));;
eval_taylor_f_bounds pp int1;;