(* Dependencies *)
needs "../formal_lp/arith/informal/informal_arith.hl";;
needs "../formal_lp/arith/informal/informal_eval_interval.hl";;


module Informal_taylor = struct

open Informal_interval;;
open Informal_float;;
open Informal_atn;;
open Informal_eval_interval;;


(* A rectangular cell with a distinguished center point, in the form built
   by mk_m_center_domain.  The four lists are indexed by coordinate. *)
type m_cell_domain = 
{
  (* lower corner of the cell (the x_list passed to mk_m_center_domain) *)
  lo : ifloat list;
  (* upper corner of the cell (the z_list passed to mk_m_center_domain) *)
  hi : ifloat list;
  (* center point; mk_m_center_domain checks lo <= y <= hi coordinatewise *)
  y : ifloat list;
  (* per-coordinate width: max of (y - lo) and (hi - y), both computed
     with sub_float_hi in mk_m_center_domain *)
  w : ifloat list;
};;


(* Second-order Taylor data of a function on a cell: an enclosure of the
   value, of the first partials and of the second partials. *)
type m_taylor_interval =
{
  (* number of variables; eval_m_taylor sets it to the number of first
     partial derivative theorems it was given *)
  n : int;
  (* the cell the data refers to *)
  domain : m_cell_domain;
  (* enclosure of the function value; eval_m_taylor evaluates it at the
     center point domain.y *)
  f : interval;
  (* enclosures of the first partials, one per variable; eval_m_taylor
     evaluates them at the center point domain.y *)
  df : interval list;
  (* enclosures of the second partials in lower-triangular layout: the
     i-th row (counting from 1) has i entries, for j = 1..i;
     eval_m_taylor evaluates them over the whole box [domain.lo, domain.hi] *)
  ddf : interval list list; 
};;


(* Small floating-point constants shared by the routines of this module. *)
let float_0 = mk_small_num_float 0;;
let float_1 = mk_small_num_float 1;;
let float_2 = mk_small_num_float 2;;

(* One half, obtained as float_1 / float_2 with div_float_lo at precision 1.
   NOTE(review): presumably this quotient is exact in the float base in
   use -- confirm, since it is used in upper-bound computations. *)
let float_inv2 = div_float_lo 1 float_1 float_2;;

(* convert_to_float_list *)
(* Evaluates every element of the HOL list term list_tm as a closed
   interval expression with precision pp (no variable bindings are
   supplied) and returns one endpoint of each result: the lower endpoint
   when lo_flag is true, the upper endpoint otherwise. *)
let convert_to_float_list pp lo_flag list_tm =
  let endpoint int =
    let lower, upper = dest_interval int in
      if lo_flag then lower else upper in
  let eval_closed fn = eval_interval_fun pp fn [] [] in
  let interval_funs = map build_interval_fun (dest_list list_tm) in
  let intervals = map eval_closed interval_funs in
    map endpoint intervals;;


(* mk_m_center_domain *)
(* Builds the cell domain record for the box with lower corner x_list and
   upper corner z_list, working with precision pp.

   For each coordinate the center y is x itself when x and z are equal,
   and otherwise float_inv2 times the sum x + z (sum taken with
   add_float_hi, product with mul_float_eq).  The width w is the larger
   of y - x and z - y, both computed with sub_float_hi.

   Fails with "mk_m_center_domain: ~(x <= y <= z)" when the computed
   center does not lie between the corners in every coordinate.
   NOTE(review): the list combinators used here presumably fail on lists
   of different lengths -- confirm against their definitions. *)
let mk_m_center_domain pp x_list z_list =
  let y_list =
    let ( * ), ( + ) = mul_float_eq, add_float_hi pp in
      map2 (fun x z -> if eq_float x z then x else float_inv2 * (x + z)) x_list z_list in

  (* test: x <= y <= z *)
  let flag1 = itlist2 (fun x y a -> le_float x y && a) x_list y_list true and
      flag2 = itlist2 (fun y z a -> le_float y z && a) y_list z_list true in
    (* the deprecated keyword form of boolean disjunction is avoided here *)
    if not flag1 || not flag2 then
      failwith "mk_m_center_domain: ~(x <= y <= z)"
    else
      let w_list =
        let ( - ) = sub_float_hi pp in
        let w1 = map2 ( - ) y_list x_list in
        let w2 = map2 ( - ) z_list y_list in
          map2 max_float w1 w2 in
        {lo = x_list; hi = z_list; y = y_list; w = w_list};;


(* eval_m_taylor (pp0 for initial evaluation of constants) *)
(* Two-stage constructor of Taylor data.

   Stage 1 (arguments pp0, f_tm, partials, partials2): the body of the
   abstraction f_tm and the right-hand sides of the conclusions of the
   theorems in partials and partials2 are turned into interval functions,
   with constants pre-evaluated at precision pp0.  partials2 must be
   triangular: its k-th list (counting from 1) must have exactly k
   theorems, otherwise the call fails with
   "eval_m_taylor: incorrect partials2".

   Stage 2 (arguments p_lin, p_second, domain): the function and its first
   partials are evaluated at the center point domain.y with precision
   p_lin; the second partials are evaluated over the box
   [domain.lo, domain.hi] with precision p_second and regrouped into
   triangular rows. *)
let eval_m_taylor pp0 f_tm partials partials2 =
  let compile = eval_constants pp0 o build_interval_fun o snd o dest_abs in
  let compile_rhs = compile o rand o concl in
  let f_fun = compile f_tm in
  let dim = length partials in
  (* Verify that the list of second partial derivatives is triangular *)
  let check_row k row =
    if length row = k then () else failwith "eval_m_taylor: incorrect partials2" in
  let _ = map2 check_row (1--dim) partials2 in
  let first_funs = map compile_rhs partials in
  let second_funs = map compile_rhs (List.flatten partials2) in
  let lin_list = find_and_replace_all (f_fun :: first_funs) [] in
  let second_list = find_and_replace_all second_funs [] in
  (* Cut a flat list into rows of lengths k, k + 1, ...; the final row
     receives whatever is left *)
  let rec triangulate flat k =
    if k < dim then
      let row, rest = chop_list k flat in
        row :: triangulate rest (k + 1)
    else
      [flat] in
  let point v = mk_interval (v, v) in
    fun p_lin p_second domain ->
      let center_ints = map point domain.y in
      let box_ints = map mk_interval (zip domain.lo domain.hi) in
      let lin_vals = eval_interval_fun_list p_lin lin_list center_ints in
      let second_vals = eval_interval_fun_list p_second second_list box_ints in
        {n = dim; domain = domain;
         f = hd lin_vals; df = tl lin_vals;
         ddf = triangulate second_vals 1};;


(* mk_eval_function *)
(* Compiles the body of the abstraction f_tm into an interval function
   (constants pre-evaluated at precision pp0) and returns an evaluator.
   The evaluator takes a precision pp and the corners x_list, z_list of a
   box and yields the interval value of the function over that box. *)
let mk_eval_function pp0 f_tm =
  let fold_consts = eval_constants pp0 in
  let body = snd (dest_abs f_tm) in
  let f_fun = fold_consts (build_interval_fun body) in
  let funs = find_and_replace_all [f_fun] [] in
    fun pp x_list z_list ->
      let box_ints = map mk_interval (zip x_list z_list) in
        hd (eval_interval_fun_list pp funs box_ints);;


(* error_mul_f2_hi *)
(* Multiplies the float a by abs_interval int using mul_float_hi at
   precision pp.
   NOTE(review): abs_interval presumably yields an upper bound of the
   absolute value over the interval -- confirm. *)
let error_mul_f2_hi pp a int =
  let magnitude = abs_interval int in
    mul_float_hi pp a magnitude;;


(* eval_m_taylor_error *)
(* sum_{i = 1}^n (w_i * (f_ii * w_i + 2 * sum_{j = 1}^{i - 1} w_j * f_ij)) *)
(* Here w is ti.domain.w and each f_ij enters through error_mul_f2_hi,
   that is via abs_interval of the stored second partial.  All sums and
   products use add_float_hi and mul_float_hi at precision pp, and the
   sums are associated to the right by end_itlist. *)
let eval_m_taylor_error pp ti =
  let widths = ti.domain.w in
  let mul_hi = mul_float_hi pp and add_hi = add_float_hi pp in
  (* the row index is not needed; pairing with 1--n keeps the length check *)
  let scale_row row _ = Arith_misc.my_map2 (error_mul_f2_hi pp) widths row in
  let scaled = map2 scale_row ti.ddf (1--ti.n) in
  let lower_rows = tl scaled in
  (* off-diagonal part of every row after the first *)
  let off_diag = map (fun row -> end_itlist add_hi (butlast row)) lower_rows in
  let combine row off = add_hi (last row) (mul_hi float_2 off) in
  let inner = hd (hd scaled) :: map2 combine lower_rows off_diag in
    end_itlist add_hi (map2 mul_hi widths inner);;
    

(* eval_m_taylor_upper_bound *)
(* Upper bound of the function on the cell: the upper endpoint of ti.f
   plus the first-order term (sum of w_i times the magnitude of df_i) plus
   float_inv2 times the second-order error from eval_m_taylor_error.
   Every operation uses the _hi float routines at precision pp. *)
let eval_m_taylor_upper_bound pp ti =
  let add_hi = add_float_hi pp in
  let _, upper = dest_interval ti.f in
  let second_order = eval_m_taylor_error pp ti in
  let first_order =
    end_itlist add_hi (map2 (error_mul_f2_hi pp) ti.domain.w ti.df) in
  let radius = add_hi first_order (mul_float_hi pp float_inv2 second_order) in
    add_hi upper radius;;

(* eval_m_taylor_lower_bound *)
(* Lower bound of the function on the cell: the lower endpoint of ti.f
   minus the same radius as in eval_m_taylor_upper_bound.  The radius is
   accumulated with the _hi float routines and subtracted with
   sub_float_lo, all at precision pp. *)
let eval_m_taylor_lower_bound pp ti =
  let add_hi = add_float_hi pp in
  let lower, _ = dest_interval ti.f in
  let second_order = eval_m_taylor_error pp ti in
  let first_order =
    end_itlist add_hi (map2 (error_mul_f2_hi pp) ti.domain.w ti.df) in
  let radius = add_hi first_order (mul_float_hi pp float_inv2 second_order) in
    sub_float_lo pp lower radius;;


(* eval_m_taylor_bound *)
(* Interval enclosing the function on the cell.  With
   radius = (sum of w_i times the magnitude of df_i)
            + float_inv2 * eval_m_taylor_error,
   the result is [lower endpoint of ti.f - radius,
                  upper endpoint of ti.f + radius].
   The radius and the upper endpoint use the _hi float routines, the lower
   endpoint uses sub_float_lo, all at precision pp. *)
let eval_m_taylor_bound pp ti =
  let add_hi = add_float_hi pp in
  let lower, upper = dest_interval ti.f in
  let second_order = eval_m_taylor_error pp ti in
  let first_order =
    end_itlist add_hi (map2 (error_mul_f2_hi pp) ti.domain.w ti.df) in
  let radius = add_hi first_order (mul_float_hi pp float_inv2 second_order) in
  let hi = add_hi upper radius in
  let lo = sub_float_lo pp lower radius in
    mk_interval (lo, hi);;


(* eval_m_taylor_partial_upper *)
(* Upper bound of the i-th first partial (i counts from 1) on the cell:
   the upper endpoint of the stored i-th partial plus the sum over j of
   w_j times the magnitude of the second partial for the pair (i, j).
   Since ti.ddf is stored in triangular form, the pair is looked up with
   the larger index selecting the row. *)
let eval_m_taylor_partial_upper pp i ti =
  let _, df_hi = dest_interval (List.nth ti.df (i - 1)) in
  let second_partial j =
    let row, col = if j <= i then i, j else j, i in
      List.nth (List.nth ti.ddf (row - 1)) (col - 1) in
  let dd_list = map second_partial (1--ti.n) in
  let terms = map2 (error_mul_f2_hi pp) ti.domain.w dd_list in
  let spread = end_itlist (add_float_hi pp) terms in
    add_float_hi pp df_hi spread;;


(* eval_m_taylor_partial_lower *)
(* Lower bound of the i-th first partial (i counts from 1) on the cell:
   the lower endpoint of the stored i-th partial minus the sum over j of
   w_j times the magnitude of the second partial for the pair (i, j).
   The sum uses add_float_hi and the subtraction uses sub_float_lo. *)
let eval_m_taylor_partial_lower pp i ti =
  let df_lo, _ = dest_interval (List.nth ti.df (i - 1)) in
  let second_partial j =
    let row, col = if j <= i then i, j else j, i in
      List.nth (List.nth ti.ddf (row - 1)) (col - 1) in
  let dd_list = map second_partial (1--ti.n) in
  let terms = map2 (error_mul_f2_hi pp) ti.domain.w dd_list in
  let spread = end_itlist (add_float_hi pp) terms in
    sub_float_lo pp df_lo spread;;


(* eval_m_taylor_partial_bound *)
(* Interval enclosing the i-th first partial (i counts from 1) on the
   cell.  The stored i-th partial is widened on both sides by the sum over
   j of w_j times the magnitude of the second partial for the pair (i, j);
   the lower endpoint is computed with sub_float_lo and the upper endpoint
   with add_float_hi. *)
let eval_m_taylor_partial_bound pp i ti =
  let df_lo, df_hi = dest_interval (List.nth ti.df (i - 1)) in
  let second_partial j =
    let row, col = if j <= i then i, j else j, i in
      List.nth (List.nth ti.ddf (row - 1)) (col - 1) in
  let dd_list = map second_partial (1--ti.n) in
  let terms = map2 (error_mul_f2_hi pp) ti.domain.w dd_list in
  let spread = end_itlist (add_float_hi pp) terms in
  let lo = sub_float_lo pp df_lo spread in
  let hi = add_float_hi pp df_hi spread in
    mk_interval (lo, hi);;


(* add *)
(* Taylor data of the sum of two functions.  The value and the first
   partials are added at precision p_lin, the second partials at precision
   p_second.  The dimension and the domain are copied from taylor1. *)
let eval_m_taylor_add p_lin p_second taylor1 taylor2 =
  let add_lin = add_interval p_lin in
  let add_second = add_interval p_second in
  let add_rows row1 row2 = map2 add_second row1 row2 in
    {
      n = taylor1.n;
      domain = taylor1.domain;
      f = add_lin taylor1.f taylor2.f;
      df = map2 add_lin taylor1.df taylor2.df;
      ddf = map2 add_rows taylor1.ddf taylor2.ddf
    };;


(* sub *)
(* Taylor data of the difference taylor1 - taylor2.  The value and the
   first partials are subtracted at precision p_lin, the second partials
   at precision p_second.  The dimension and the domain are copied from
   taylor1. *)
let eval_m_taylor_sub p_lin p_second taylor1 taylor2 =
  let sub_lin = sub_interval p_lin in
  let sub_second = sub_interval p_second in
  let sub_rows row1 row2 = map2 sub_second row1 row2 in
    {
      n = taylor1.n;
      domain = taylor1.domain;
      f = sub_lin taylor1.f taylor2.f;
      df = map2 sub_lin taylor1.df taylor2.df;
      ddf = map2 sub_rows taylor1.ddf taylor2.ddf
    };;
   

(* mul *)
(* Taylor data of the product of two functions (dimension and domain are
   taken from ti1).

   Value and first partials use the stored values of ti1 and ti2 at
   precision p_lin:
     f    = f1 * f2
     df_i = d1_i * f2 + f1 * d2_i
   Second partials use, at precision p_second, the bounds of the functions
   and of their first partials over the cell (eval_m_taylor_bound and
   eval_m_taylor_partial_bound):
     ddf_ij = (dd1_ij * F2 + D1_i * D2_j) + (D1_j * D2_i + F1 * dd2_ij) *)
let eval_m_taylor_mul p_lin p_second ti1 ti2 =
  let dim = ti1.n in
  let indices = 1--dim in
  let mul_lin = mul_interval p_lin and add_lin = add_interval p_lin in
  let mul_sec = mul_interval p_second and add_sec = add_interval p_second in
  let value = mul_lin ti1.f ti2.f in
  let gradient =
    map2 (fun g1 g2 -> add_lin (mul_lin g1 ti2.f) (mul_lin ti1.f g2))
      ti1.df ti2.df in
  let grad1 = map (fun i -> eval_m_taylor_partial_bound p_second i ti1) indices in
  let grad2 = map (fun i -> eval_m_taylor_partial_bound p_second i ti2) indices in
  let range1 = eval_m_taylor_bound p_second ti1 in
  let range2 = eval_m_taylor_bound p_second ti2 in
  (* bound of the k-th first partial, k counting from 1 *)
  let nth_bound bounds k = List.nth bounds (k - 1) in
  let hessian_row (row1, row2) i =
    let di1 = nth_bound grad1 i in
    let di2 = nth_bound grad2 i in
    let entry (dd1, dd2) j =
      let dj1 = nth_bound grad1 j in
      let dj2 = nth_bound grad2 j in
        add_sec
          (add_sec (mul_sec dd1 range2) (mul_sec di1 dj2))
          (add_sec (mul_sec dj1 di2) (mul_sec range1 dd2)) in
      map2 entry (zip row1 row2) (1--i) in
  let hessian = map2 hessian_row (zip ti1.ddf ti2.ddf) indices in
    {
      n = dim;
      domain = ti1.domain;
      f = value;
      df = gradient;
      ddf = hessian;
    };;


(* inv *)
(* Taylor data of the reciprocal 1 / f.

   Value and first partials use the stored value ti.f at precision p_lin:
     value = inv f
     df_i  = -(inv (f * f)) * f_i
   Second partials use, at precision p_second, the bound F of f over the
   cell and the bounds D_i of its first partials:
     ddf_ij = (2 * inv (F * (F * F))) * D_j * D_i - inv (F * F) * f_ij *)
let eval_m_taylor_inv p_lin p_second ti =
  let dim = ti.n in
  let indices = 1--dim in
  let mul_lin = mul_interval p_lin in
  let mul_sec = mul_interval p_second and inv_sec = inv_interval p_second in
  let range = eval_m_taylor_bound p_second ti in
  let value = inv_interval p_lin ti.f in
  let slope = neg_interval (inv_interval p_lin (mul_lin ti.f ti.f)) in
  let gradient = map (fun d -> mul_lin slope d) ti.df in
  let grad_bounds =
    map (fun i -> eval_m_taylor_partial_bound p_second i ti) indices in
  let d1, d2 =
    let square = mul_sec range range in
      inv_sec square, mul_sec two_interval (inv_sec (mul_sec range square)) in
  let entry di dd dj =
    sub_interval p_second (mul_sec (mul_sec d2 dj) di) (mul_sec d1 dd) in
  let hessian =
    map2 (fun row di -> Arith_misc.my_map2 (entry di) row grad_bounds)
      ti.ddf grad_bounds in
    {
      n = dim;
      domain = ti.domain;
      f = value;
      df = gradient;
      ddf = hessian;
    };;


(* sqrt *)
(* Taylor data of the square root of f.

   Value and first partials use the stored value ti.f at precision p_lin:
     value = sqrt f
     df_i  = inv (2 * value) * f_i
   Second partials use, at precision p_second, the bound F of f over the
   cell and the bounds D_i of its first partials.  With S = 2 * sqrt F:
     ddf_ij = -(inv (S * (2 * F))) * D_j * D_i + inv S * f_ij *)
let eval_m_taylor_sqrt p_lin p_second ti =
  let dim = ti.n in
  let indices = 1--dim in
  let mul_lin = mul_interval p_lin in
  let mul_sec = mul_interval p_second and inv_sec = inv_interval p_second in
  let range = eval_m_taylor_bound p_second ti in
  let value = sqrt_interval p_lin ti.f in
  let slope = inv_interval p_lin (mul_lin two_interval value) in
  let gradient = map (fun d -> mul_lin slope d) ti.df in
  let grad_bounds =
    map (fun i -> eval_m_taylor_partial_bound p_second i ti) indices in
  let d1, d2 =
    let two_sqrt = mul_sec two_interval (sqrt_interval p_second range) in
      inv_sec two_sqrt,
      neg_interval (inv_sec (mul_sec two_sqrt (mul_sec two_interval range))) in
  let entry di dd dj =
    add_interval p_second (mul_sec (mul_sec d2 dj) di) (mul_sec d1 dd) in
  let hessian =
    map2 (fun row di -> Arith_misc.my_map2 (entry di) row grad_bounds)
      ti.ddf grad_bounds in
    {
      n = dim;
      domain = ti.domain;
      f = value;
      df = gradient;
      ddf = hessian;
    };;


(* atn *)
(* Taylor data of the arctangent of f.

   Value and first partials use the stored value ti.f at precision p_lin:
     value = atn f
     df_i  = inv (1 + f * f) * f_i
   Second partials use, at precision p_second, the bound F of f over the
   cell and the bounds D_i of its first partials.  With U = inv (1 + F * F):
     ddf_ij = ((-2) * F) * (U pow 2) * D_j * D_i + U * f_ij

   The interval for -2 is computed once, when this definition is
   evaluated. *)
let eval_m_taylor_atn =
  let neg_two_interval = neg_interval two_interval in
    fun p_lin p_second ti ->
      let dim = ti.n in
      let indices = 1--dim in
      let mul_lin = mul_interval p_lin in
      let mul_sec = mul_interval p_second and add_sec = add_interval p_second in
      let range = eval_m_taylor_bound p_second ti in
      let value = atn_interval p_lin ti.f in
      let slope =
        inv_interval p_lin
          (add_interval p_lin one_interval (mul_lin ti.f ti.f)) in
      let gradient = map (fun d -> mul_lin slope d) ti.df in
      let grad_bounds =
        map (fun i -> eval_m_taylor_partial_bound p_second i ti) indices in
      let d1, d2 =
        let square = pow_interval p_second 2 in
        let inv_one_ff =
          inv_interval p_second (add_sec one_interval (mul_sec range range)) in
          inv_one_ff, mul_sec (mul_sec neg_two_interval range) (square inv_one_ff) in
      let entry di dd dj =
        add_sec (mul_sec (mul_sec d2 dj) di) (mul_sec d1 dd) in
      let hessian =
        map2 (fun row di -> Arith_misc.my_map2 (entry di) row grad_bounds)
          ti.ddf grad_bounds in
        {
          n = dim;
          domain = ti.domain;
          f = value;
          df = gradient;
          ddf = hessian;
        };;


(* acs *)
(* Taylor data of the arccosine of f.

   Value and first partials use the stored value ti.f at precision p_lin:
     value = acs f
     df_i  = -(inv (sqrt (1 - f * f))) * f_i
   Second partials use, at precision p_second, the bound F of f over the
   cell and the bounds D_i of its first partials.  With G = 1 - F * F:
     ddf_ij = -(F / sqrt (G pow 3)) * D_j * D_i - inv (sqrt G) * f_ij *)
let eval_m_taylor_acs p_lin p_second ti =
  let dim = ti.n in
  let indices = 1--dim in
  let mul_lin = mul_interval p_lin in
  let mul_sec = mul_interval p_second and sqrt_sec = sqrt_interval p_second in
  let range = eval_m_taylor_bound p_second ti in
  let value = acs_interval p_lin ti.f in
  let slope =
    let one_minus_ff = sub_interval p_lin one_interval (mul_lin ti.f ti.f) in
      neg_interval (inv_interval p_lin (sqrt_interval p_lin one_minus_ff)) in
  let gradient = map (fun d -> mul_lin slope d) ti.df in
  let grad_bounds =
    map (fun i -> eval_m_taylor_partial_bound p_second i ti) indices in
  let d1, d2 =
    let cube = pow_interval p_second 3 in
    let ff_1 = sub_interval p_second one_interval (mul_sec range range) in
      inv_interval p_second (sqrt_sec ff_1),
      neg_interval (div_interval p_second range (sqrt_sec (cube ff_1))) in
  let entry di dd dj =
    sub_interval p_second (mul_sec (mul_sec d2 dj) di) (mul_sec d1 dd) in
  let hessian =
    map2 (fun row di -> Arith_misc.my_map2 (entry di) row grad_bounds)
      ti.ddf grad_bounds in
    {
      n = dim;
      domain = ti.domain;
      f = value;
      df = gradient;
      ddf = hessian;
    };;
		    

end;;


(*
(* Tests *)

open Informal_taylor;;

let dest_int int =
  let f1, f2 = Informal_interval.dest_interval int in
    Informal_float.dest_float f1, Informal_float.dest_float f2;;

let dest_ti ti =
  dest_int ti.f, map dest_int ti.df, map (map dest_int) ti.ddf;;

let dest_f = Informal_float.dest_float;;


needs "../formal_lp/formal_interval/m_taylor_arith2.hl";;

let convert_to_float_list pp lo_flag list_tm =
  let tms = dest_list list_tm in
  let i_funs = map build_interval_fun tms in
  let ints = map (fun f -> eval_interval_fun pp f [] []) i_funs in
  let extract = (if lo_flag then fst else snd) o dest_pair o rand o concl in
    mk_list (map extract ints, real_ty);;


let pp = 7;;
let poly = expr_to_vector_fun `x1 + x2 * x3 + x3 * (x1 + x3 pow 2)`;;
let n = (get_dim o fst o dest_abs) poly;;

let xx = `[#1.1; &2; -- sqrt(&2)]` and
    zz = `[&3; &3; &1 + sqrt(&3)]`;;

let xx1 = convert_to_float_list pp true xx and
    zz1 = convert_to_float_list pp false zz;;

let xx0 = Informal_taylor.convert_to_float_list pp true xx and
    zz0 = Informal_taylor.convert_to_float_list pp false zz;;

let dom_th = mk_m_center_domain n pp xx1 zz1;;
let dom = Informal_taylor.mk_m_center_domain pp xx0 zz0;;

let partials = map (fun i -> gen_partial_poly i poly) (1--n);;
let get_partial i eq_th =
  let partial_i = gen_partial_poly i (rand (concl eq_th)) in
  let pi = (rator o lhand o concl) partial_i in
    REWRITE_RULE[GSYM partial2] (TRANS (AP_TERM pi eq_th) partial_i);;
let partials2 = map (fun j ->
		       let th = List.nth partials (j - 1) in
			 map (fun i -> get_partial i th) (1--j)) (1--n);;

let diff_th = gen_diff_poly poly;;
let diff2_th = gen_diff2c_domain_poly poly;;
let lin_th = gen_lin_approx_poly_thm poly diff_th partials;;
let second_th = gen_second_bounded_poly_thm poly partials2;;

let eval_taylor = eval_m_taylor pp diff2_th lin_th second_th;;
let taylor = Informal_taylor.eval_m_taylor pp poly partials partials2;;

let ti_th = eval_taylor pp pp dom_th;;
let ti = taylor pp pp dom;;
dest_ti ti;;

eval_m_taylor_bound n pp ti_th;;
dest_int (Informal_taylor.eval_m_taylor_bound pp ti);;

eval_m_taylor_partial_upper n pp 3 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_upper pp 3 ti);;

let t2_th = eval_m_taylor_sub n 2 5 ti_th ti_th;;
let t2 = Informal_taylor.eval_m_taylor_sub 2 5 ti ti;;
dest_ti t2;;

eval_m_taylor_sub n 8 8 ti_th t2_th;;
dest_ti (Informal_taylor.eval_m_taylor_sub 8 8 ti t2);;

let xx = `[#0.0; &0; sqrt(&0)]` and
    zz = `[#0.2; #0.1; sqrt(&0) + #0.1]`;;

let xx1 = convert_to_float_list pp true xx and
    zz1 = convert_to_float_list pp false zz;;

let xx0 = Informal_taylor.convert_to_float_list pp true xx and
    zz0 = Informal_taylor.convert_to_float_list pp false zz;;

let dom_th = mk_m_center_domain n pp xx1 zz1;;
let dom = Informal_taylor.mk_m_center_domain pp xx0 zz0;;


let ti_th = eval_taylor pp pp dom_th;;
let ti = taylor pp pp dom;;
let th = eval_m_taylor_acs n pp pp ti_th;;
let t = Informal_taylor.eval_m_taylor_acs pp pp ti;;
dest_ti t;;

eval_m_taylor_bound n 20 th;;
dest_int (Informal_taylor.eval_m_taylor_bound 20 t);;

eval_m_taylor_partial_bound n 20 2 th;;
dest_int (Informal_taylor.eval_m_taylor_partial_bound 20 2 t);;

eval_m_taylor_mul n pp pp ti_th th;;
dest_ti (Informal_taylor.eval_m_taylor_mul pp pp ti t);;


(* 1.20 *)
test 100 eval_taylor dom_th;;
(* 0.04 *)
test 100 taylor dom;;

(* bounds *)
eval_m_taylor_bound n pp ti_th;;
dest_int (Informal_taylor.eval_m_taylor_bound pp ti);;

eval_m_taylor_upper_bound n pp ti_th;;
dest_f (Informal_taylor.eval_m_taylor_upper_bound pp ti);;

eval_m_taylor_lower_bound n pp ti_th;;
dest_f (Informal_taylor.eval_m_taylor_lower_bound pp ti);;


(* 1.288 *)
test 100 (eval_m_taylor_bound n pp) ti_th;;
(* 0.044 *)
test 100 (Informal_taylor.eval_m_taylor_bound pp) ti;;



(* partials *)

eval_m_taylor_upper_partial n pp 1 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_upper pp 1 ti);;

eval_m_taylor_upper_partial n pp 2 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_upper pp 2 ti);;

eval_m_taylor_upper_partial n pp 3 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_upper pp 3 ti);;


eval_m_taylor_lower_partial n pp 1 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_lower pp 1 ti);;

eval_m_taylor_lower_partial n pp 2 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_lower pp 2 ti);;

eval_m_taylor_lower_partial n pp 3 ti_th;;
dest_f (Informal_taylor.eval_m_taylor_partial_lower pp 3 ti);;

eval_m_taylor_interval_partial n pp 1 ti_th;;
dest_int (Informal_taylor.eval_m_taylor_partial_bound pp 1 ti);;
*)