needs "../formal_lp/formal_interval/more_float.hl";;

(* The num-typed variables x and y; mk_float_interval_decimal uses them as
   the variables instantiated (via INST) in SPEC_ALL DECIMAL. *)
let x_var_num = `x:num`;;
let y_var_num = `y:num`;;

(* The constant (&) : num -> real; build_interval_fun compares applications
   against it to turn [&n] numerals into exact intervals. *)
let amp_op_real = `(&):num->real`;;

(* [mk_float_interval_decimal pp decimal_tm] creates an interval theorem
   approximating the decimal literal [decimal_tm = DECIMAL n d] at
   precision [pp].
   The numerals [n] and [d] (extracted with dest_numeral) are turned into
   intervals with mk_float_interval_num and divided with float_interval_div.
   The quotient is then passed to norm_interval together with the
   instantiated DECIMAL definition; presumably norm_interval rewrites the
   interval's term back to [decimal_tm] -- TODO confirm in more_float.hl.
   SPEC_ALL DECIMAL is computed once, when this function is defined. *)
let mk_float_interval_decimal =
  let decimal_def = SPEC_ALL DECIMAL in
  fun pp decimal_tm ->
    let num_tm, den_tm = dest_binary "DECIMAL" decimal_tm in
    let num_val, den_val = dest_numeral num_tm, dest_numeral den_tm in
    let num_int, den_int =
      mk_float_interval_num num_val, mk_float_interval_num den_val in
    let quotient = float_interval_div pp num_int den_int in
    let def_th = INST [num_tm, x_var_num; den_tm, y_var_num] decimal_def in
    norm_interval quotient def_th;;


(* Unary interval operations: maps an operator constant to a function that
   takes the precision [pp] and an interval theorem for the argument and
   returns an interval theorem for the result.
   float_interval_neg takes no precision argument, so its entry ignores [pp]. *)
let unary_interval_operations =
  let table = Hashtbl.create 10 in
  List.iter (fun (op_tm, op_fun) -> Hashtbl.add table op_tm op_fun)
    [ (`--`, (fun _ -> float_interval_neg));
      (`inv`, float_interval_inv);
      (`sqrt`, float_interval_sqrt);
      (`atn`, float_interval_atn);
      (`acs`, float_interval_acs) ];
  table;;


(* Binary interval operations: maps an operator constant to a function that
   takes the precision [pp] and interval theorems for the left and right
   operands and returns an interval theorem for the result. *)
let binary_interval_operations =
  let table = Hashtbl.create 10 in
  List.iter (fun (op_tm, op_fun) -> Hashtbl.add table op_tm op_fun)
    [ (`+`, float_interval_add);
      (`-`, float_interval_sub);
      (`*`, float_interval_mul);
      (`/`, float_interval_div) ];
  table;;


(* Interval approximations of named constants: maps a constant term to a
   function from the precision [pp] to an interval theorem.
   [pi] is read from the precomputed pi_approx_array at index [pp]. *)
let interval_constants =
  let table = Hashtbl.create 10 in
  List.iter (fun (const_tm, approx) -> Hashtbl.add table const_tm approx)
    [ (`pi`, (fun pp -> pi_approx_array.(pp))) ];
  table;;



(* Type of an interval function: an expression tree that eval_interval_fun
   evaluates to an interval theorem at a given precision *)
type interval_fun =
  | Int_ref of int
      (* 0-based index into the [refs] list passed to eval_interval_fun;
         introduced by replace_subexpr *)
  | Int_var of term
      (* term looked up in the [vars] association list; build_interval_fun
         also uses it for [$] (component) applications *)
  | Int_const of thm
      (* an already computed interval theorem, returned unchanged *)
  | Int_decimal_const of term
      (* a [DECIMAL n d] literal, evaluated by mk_float_interval_decimal *)
  | Int_named_const of term
      (* a constant looked up in interval_constants *)
  | Int_pow of int * interval_fun
      (* [f pow n] with a small numeral exponent [n] *)
  | Int_unary of term * interval_fun
      (* operator constant (a key of unary_interval_operations) and its
         argument *)
  | Int_binary of term * interval_fun * interval_fun;;
      (* operator constant (a key of binary_interval_operations) with its
         left and right operands *)


(* [eval_interval_fun pp ifun vars refs] computes an interval theorem for
   [ifun] at precision [pp].
   - [vars] is an association list from variable terms to interval theorems
     (the evaluation point);
   - [refs] is the list of values of shared subexpressions: [Int_ref i]
     evaluates to its i-th element (0-based).
   Operators and named constants are resolved through
   unary_interval_operations, binary_interval_operations and
   interval_constants; Hashtbl.find raises Not_found for a missing entry. *)
let eval_interval_fun pp ifun vars refs =
  let unary_op = Hashtbl.find unary_interval_operations in
  let binary_op = Hashtbl.find binary_interval_operations in
  let named_const = Hashtbl.find interval_constants in
  let rec eval = function
    | Int_ref idx -> List.nth refs idx
    | Int_var var_tm -> assoc var_tm vars
    | Int_const th -> th
    | Int_decimal_const dec_tm -> mk_float_interval_decimal pp dec_tm
    | Int_named_const const_tm -> named_const const_tm pp
    | Int_pow (n, arg) -> float_interval_pow_simple pp n (eval arg)
    | Int_unary (op, arg) -> unary_op op pp (eval arg)
    | Int_binary (op, lhs, rhs) -> binary_op op pp (eval lhs) (eval rhs) in
  eval ifun;;


(* [eval_constants pp ifun] pre-evaluates, at precision [pp], every
   subexpression of [ifun] that does not depend on variables or references.
   Decimal and named constants become [Int_const] theorems.
   A power, unary or binary node whose operands all fold to [Int_const] is
   itself replaced by the [Int_const] of its value; otherwise the node is
   rebuilt with its folded operands.
   [Int_ref], [Int_var] and [Int_const] nodes are returned unchanged. *)
let eval_constants pp ifun =
  let unary_op = Hashtbl.find unary_interval_operations
  and binary_op = Hashtbl.find binary_interval_operations
  and named_const = Hashtbl.find interval_constants in
  let rec fold_consts = function
    | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
    | Int_named_const tm -> Int_const (named_const tm pp)
    | Int_pow (n, arg) ->
        (match fold_consts arg with
         | Int_const th -> Int_const (float_interval_pow_simple pp n th)
         | arg' -> Int_pow (n, arg'))
    | Int_unary (op, arg) ->
        (match fold_consts arg with
         | Int_const th -> Int_const (unary_op op pp th)
         | arg' -> Int_unary (op, arg'))
    | Int_binary (op, lhs, rhs) ->
        let lhs', rhs' = fold_consts lhs, fold_consts rhs in
        (match lhs', rhs' with
         | Int_const th1, Int_const th2 -> Int_const (binary_op op pp th1 th2)
         | _ -> Int_binary (op, lhs', rhs'))
    | (Int_ref _ | Int_var _ | Int_const _) as leaf -> leaf in
  fold_consts ifun;;



(**************************************)

(* [build_interval_fun expr_tm] converts a real-valued term into an
   interval_fun tree:
   - a constant becomes [Int_named_const]; a variable becomes [Int_var];
   - [&n] with a numeral [n] becomes an exact [Int_const] interval;
   - any other constant applied to one argument becomes [Int_unary];
   - [DECIMAL n d] becomes [Int_decimal_const];
   - [x pow n] becomes [Int_pow] (n must be a small numeral);
   - [x $ i] is treated as a variable ([Int_var]);
   - any other binary application becomes [Int_binary].
   Fails (via dest_comb / dest_const / dest_numeral) on terms of other
   shapes. *)
let rec build_interval_fun expr_tm =
  if is_const expr_tm then
    Int_named_const expr_tm
  else if is_var expr_tm then
    Int_var expr_tm
  else
    let op_tm, arg_tm = dest_comb expr_tm in
    if is_const op_tm then begin
      (* Unary application; [&n] numerals are converted directly *)
      if op_tm = amp_op_real then
        Int_const (mk_float_interval_num (dest_numeral arg_tm))
      else
        Int_unary (op_tm, build_interval_fun arg_tm)
    end else
      (* Binary application [bin_op lhs arg] *)
      let bin_op, lhs_tm = dest_comb op_tm in
      match (fst o dest_const) bin_op with
      | "DECIMAL" -> Int_decimal_const expr_tm
      | "real_pow" ->
          let n = dest_small_numeral arg_tm in
          Int_pow (n, build_interval_fun lhs_tm)
      | "$" -> Int_var expr_tm
      | _ ->
          let lhs = build_interval_fun lhs_tm
          and rhs = build_interval_fun arg_tm in
          Int_binary (bin_op, lhs, rhs);;


(*
let pp = 5;;
let test_vars = [`x:real`, two_interval];;
let f = build_interval_fun `(&1 + &3 * pi) + sqrt (#3.13525238353 * x)`;;
let f2 = eval_constants pp f;;
eval_interval_fun pp f test_vars [];;
eval_interval_fun pp f2 test_vars [];;

test 100 (fun vars -> eval_interval_fun pp f vars []) test_vars;;
test 100 (fun vars -> eval_interval_fun pp f2 vars []) test_vars;;
*)

(* [replace_subexpr expr expr_index f_list] replaces every occurrence of the
   subexpression [expr] (compared with structural equality) in each interval
   function of [f_list] by [Int_ref expr_index].
   Returns the total number of replacements and the list of rewritten
   functions, in the same order as [f_list]. Once a node matches [expr] it is
   replaced as a whole and its children are not searched. *)
let replace_subexpr expr expr_index f_list =
  let rec subst f =
    if f = expr then 1, Int_ref expr_index
    else
      match f with
      | Int_pow (k, arg) ->
          let hits, arg' = subst arg in
          hits, Int_pow (k, arg')
      | Int_unary (op, arg) ->
          let hits, arg' = subst arg in
          hits, Int_unary (op, arg')
      | Int_binary (op, lhs, rhs) ->
          let hits_l, lhs' = subst lhs in
          let hits_r, rhs' = subst rhs in
          hits_l + hits_r, Int_binary (op, lhs', rhs')
      | Int_ref _ | Int_var _ | Int_const _
      | Int_decimal_const _ | Int_named_const _ -> 0, f in
  let results = map subst f_list in
  itlist (fun (hits, _) total -> hits + total) results 0, map snd results;;


		
(* [is_leaf f] is false exactly for the compound nodes ([Int_pow],
   [Int_unary], [Int_binary]) and true for every other constructor. *)
let is_leaf = function
  | Int_pow _ | Int_unary _ | Int_binary _ -> false
  | Int_ref _ | Int_var _ | Int_const _
  | Int_decimal_const _ | Int_named_const _ -> true;;

(* [find_and_replace_all f_list acc] performs common-subexpression
   elimination on the interval functions of [f_list].
   Returns [(fs, refs)] where:
   - [fs] are the functions of [f_list], in the same order, with shared
     subexpressions replaced by [Int_ref i] nodes;
   - [refs] is [acc] extended with the extracted subexpressions. A new
     reference is numbered [length acc] at the time it is extracted and
     appended at the end, so [Int_ref i] denotes the i-th element (0-based)
     of [refs].
   Leaves (see is_leaf) are never extracted. A subexpression extracted here
   may contain [Int_ref] nodes created by earlier extractions, so [refs]
   must be evaluated in order, as eval_interval_fun_list does. *)
let find_and_replace_all f_list acc =
  (* Depth-first search of [f] for a non-leaf subexpression occurring more
     than once in [f_list]. Children are tried before the node itself; for a
     binary node the left operand is tried before the right one.
     Returns [(expr, (c, fs))]: when [c > 1], [expr] is such a subexpression
     and [fs] is [f_list] with all its occurrences replaced by [Int_ref i];
     when [c <= 1] nothing was found and callers ignore [fs]. *)
  let rec find_and_replace f i f_list =
    if is_leaf f then
      f, (0, f_list)
    else
      let expr, (c, fs) =
	match f with
	  | Int_pow (k, f1) -> find_and_replace f1 i f_list
	  | Int_unary (tm, f1) -> find_and_replace f1 i f_list
	  | Int_binary (tm, f1, f2) ->
	      let expr, (c1, fs) = find_and_replace f1 i f_list in
		if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
	  (* unreachable: leaves are handled by the is_leaf test above *)
	  | _ -> f, (0, f_list) in
	(* No repeated subexpression strictly inside [f]: count [f] itself *)
	if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
    
  (* Repeatedly extracts repeated subexpressions of the first function
     [hd fs] until none is left; returns the rewritten list and the extended
     reference list. *)
  let rec iterate fs acc =
    let i = length acc in
    let expr, (c, fs') = find_and_replace (hd fs) i fs in
      if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in

  (* Processes the functions one at a time: once the head has no repeated
     subexpression left it is moved to [f_acc], and the rewritten tail is
     processed next. The names [f] and [fs] bound by the pattern are unused;
     only non-emptiness of [f_list] is tested. *)
  let rec iterate_all f_list ref_acc f_acc =
    match f_list with
      | [] -> f_acc, ref_acc
      | f :: fs ->
	  let fs', acc' = iterate f_list ref_acc in
	    iterate_all (tl fs') acc' (f_acc @ [hd fs']) in

    iterate_all f_list acc [];;


(* [eval_interval_fun_list pp (f_list, refs) vars] evaluates a list of
   interval functions that share the reference expressions [refs], such as
   the pair returned by find_and_replace_all.
   The references are evaluated first, in order: the i-th one sees the
   values of references 0..i-1 as its [Int_ref] environment. Every function
   of [f_list] is then evaluated at [vars] with all reference values
   available. Returns the interval theorems for [f_list], in order. *)
let eval_interval_fun_list pp (f_list, refs) vars =
  let ref_vals =
    List.fold_left
      (fun done_vals r -> done_vals @ [eval_interval_fun pp r vars done_vals])
      [] refs in
  map (fun f -> eval_interval_fun pp f vars ref_vals) f_list;;


(*
let pp = 5;;
let test_vars = [`x:real`, two_interval];;
let test_expr1 = `x * x + (&3 * x) * x * x + &3 * x + &3 * x`;;
let test_expr2 = `(x * x) * (x * &2) + x * &2`;;
let subexpr1 = `x * x` and subexpr2 = `&3 * x`;;
let test_f1 = build_interval_fun test_expr1 and
    test_f2 = build_interval_fun test_expr2;;
let sub1 = build_interval_fun subexpr1 and sub2 = build_interval_fun subexpr2;;


let v = find_and_replace_all [test_f1; test_f2] [];;

eval_interval_fun_list pp v test_vars;;
map (fun f -> eval_interval_fun pp f test_vars []) [test_f1; test_f2];;

(* 0.712 *)
test 100 (map (fun f -> eval_interval_fun pp f test_vars [])) [test_f1; test_f2];;
(* 0.432 *)
test 100 (eval_interval_fun_list pp v) test_vars;;
*)