flyspeck_needs "../formal_lp/formal_interval/interval_m/verifier.hl";;
flyspeck_needs "../formal_lp/formal_interval/eval_interval.hl";;

open Verifier;;
open Interval;;
open Line_interval;;
open Taylor;;
open Recurse;;

(* Automatic conversion of formal interval functions into functions *)

(* Translates a formal interval function description (Int_var, Int_const, *)
(* Int_pow, Int_unary, Int_binary) into the corresponding F_int_* value.  *)
(* Fails with an Unsupported operator message for unary operators other   *)
(* than neg_op_real and binary operators other than add/sub/mul, and with *)
(* Unsupported function for any other constructor.                        *)
let rec build_fun i_fun =
  let unsupported op =
    failwith ("Unsupported operator: "^string_of_term op) in
  match i_fun with
    | Int_var tm ->
        (* First try to read the index as a small numeral in the operand   *)
        (* of tm; otherwise parse it from the variable name with its first *)
        (* character dropped.                                              *)
        (try F_int_var (dest_small_numeral (rand tm))
         with Failure _ ->
           let name = fst (dest_var tm) in
           let digits = String.sub name 1 (String.length name - 1) in
             F_int_var (int_of_string digits))
    | Int_const th ->
        (* The bounds are the two components of the pair found in the *)
        (* operand of the conclusion of th.                           *)
        let lo_tm, hi_tm = dest_pair (rand (concl th)) in
          F_int_const
            (mk_interval (float_of_float_tm lo_tm, float_of_float_tm hi_tm))
    | Int_pow (n, f) ->
        F_int_pow (n, build_fun f)
    | Int_unary (op, f) ->
        (* The argument is translated before the operator is inspected. *)
        let arg = build_fun f in
          if op = neg_op_real then F_int_neg arg else unsupported op
    | Int_binary (op, f1, f2) ->
        let lhs, rhs = build_fun f1, build_fun f2 in
          if op = add_op_real then F_int_add (lhs, rhs)
          else if op = sub_op_real then F_int_sub (lhs, rhs)
          else if op = mul_op_real then F_int_mul (lhs, rhs)
          else unsupported op
    | _ -> failwith "Unsupported function";;

(* Builds evaluators from the theorem lin_th.                              *)
(* The operand of the conclusion of lin_th is split into conjuncts; the    *)
(* first component of dest_interval_arith of each conjunct is turned into  *)
(* an evaluator (constants evaluated with parameter pp).  The list is      *)
(* padded with constant-zero functions by 8 - length funs + 1 entries.     *)
(* The head evaluator is used as the function itself and the remaining     *)
(* ones as the entries selected by a positive index.                       *)
(* Returns a pair:                                                         *)
(*   - fun i x z: evaluates entry i (0 = head) on the intervals obtained   *)
(*     by pairing the elements of x and z;                                 *)
(*   - fun x: evaluates all entries on the degenerate intervals (v, v)     *)
(*     and packs the results with mk_line.                                 *)
let buildL pp lin_th =
  let conjuncts = striplist dest_conj (rand (concl lin_th)) in
  let funs = map (fun tm -> fst (dest_interval_arith tm)) conjuncts in
  let eval_pp = eval_constants pp in
  let i_funs = map (fun tm -> eval_pp (build_interval_fun tm)) funs in
  let fs =
    map build_fun i_funs @
      (replicate (F_int_const zero) (8 - length funs + 1)) in
  let eval_fs = map eval_int_fun fs in
  let f, df = hd eval_fs, tl eval_fs in
  let eval_component i x z =
    let box = map2 (fun lo hi -> mk_interval (lo, hi)) x z in
    let g = if i = 0 then f else List.nth df (i - 1) in
      g box in
  let eval_line x =
    let point = map (fun v -> mk_interval (v, v)) x in
      mk_line (f point, map (fun d -> d point) df) in
    eval_component, eval_line;;

(* Same as buildL, but the theorem is first generated from poly_tm by *)
(* gen_lin_approx_poly_thm0.                                          *)
let buildL0 pp poly_tm =
  buildL pp (gen_lin_approx_poly_thm0 poly_tm);;

(* Builds an evaluator for a symmetric table of entries from second_th.   *)
(* The dimension n is obtained with get_dim from the bound variable of    *)
(* the abstraction extracted from the conclusion.  The conjuncts under    *)
(* the quantifier are converted to evaluators (constants evaluated with   *)
(* parameter pp) and are read as a row-packed lower triangle: entry       *)
(* (i, j) with j <= i sits at position (i - 1) * i / 2 + (j - 1), and     *)
(* (i, j) with j > i is read from (j, i).                                 *)
(* The returned function takes lists x and z, pairs them into intervals,  *)
(* evaluates every entry and returns n rows of n values, each row         *)
(* extended by 8 - n zeros, followed by 8 - n rows of 8 zeros.            *)
let buildDD pp second_th =
  let concl_tm = concl second_th in
  let poly_tm = lhand (rator (lhand concl_tm)) in
  let n = get_dim (fst (dest_abs poly_tm)) in
  let ns = 1--n in
  let body = snd (dest_forall (rand concl_tm)) in
  let funs = striplist dest_conj (rand body) in
  let eval_pp = eval_constants pp in
  let i_funs =
    map (fun tm ->
           eval_pp (build_interval_fun (fst (dest_interval_arith tm))))
      funs in
  let fs0 = map build_fun i_funs in
  let zero_row = replicate zero 8 in
  let row_pad = replicate zero (8 - n) in
  let extra_rows = replicate zero_row (8 - n) in
  (* Position of entry (i, j) inside the packed lower triangle. *)
  let packed_index i j =
    let r, c = if j <= i then i, j else j, i in
      (r - 1) * r / 2 + (c - 1) in
  let eval_fs = map eval_int_fun fs0 in
  let eval_dd x z =
    let ints = map2 (fun lo hi -> mk_interval (lo, hi)) x z in
    let vals = map (fun f -> f ints) eval_fs in
    let row i =
      map (fun j -> List.nth vals (packed_index i j)) ns @ row_pad in
      map row ns @ extra_rows in
    eval_dd;;

(* Same as buildDD, but the theorem is first generated from poly_tm by *)
(* gen_second_bounded_poly_thm0.                                       *)
let buildDD0 pp poly_tm =
  buildDD pp (gen_second_bounded_poly_thm0 poly_tm);;


(* Combines the results of buildL (on lin_th) and buildDD (on second_th) *)
(* into a Prim_a value via make_primitiveA.                              *)
let build_taylor pp lin_th second_th =
  let f_df, lin = buildL pp lin_th in
  let dd = buildDD pp second_th in
  let prim = make_primitiveA (f_df, lin, dd) in
    Prim_a prim;;

(* Same as build_taylor, but both theorems are generated from poly_tm. *)
let build_taylor0 pp poly_tm =
  let second_th = gen_second_bounded_poly_thm0 poly_tm in
  let lin_th = gen_lin_approx_poly_thm0 poly_tm in
    build_taylor pp lin_th second_th;;