needs "../formal_lp/arith/informal/informal_arith.hl";;

module Informal_eval_interval = struct

open Informal_interval;;
open Informal_float;;
open Informal_atn;;

(* Builds an interval enclosing the value of a DECIMAL term.
   The term is split into its two numeral arguments, each numeral is
   turned into an interval with mk_num_interval, and the first interval
   is divided by the second with div_interval.
   pp is passed through unchanged to div_interval. *)
let mk_float_interval_decimal pp decimal_tm =
  let numer_tm, denom_tm = dest_binary "DECIMAL" decimal_tm in
  let numer, denom = dest_numeral numer_tm, dest_numeral denom_tm in
  let numer_int, denom_int = mk_num_interval numer, mk_num_interval denom in
    div_interval pp numer_int denom_int;;


(* Table of unary interval operations, keyed by the name of the
   corresponding constant. Each entry takes pp and then the argument
   interval; negation ignores pp. *)
let unary_interval_operations =
  let operations = [
    "real_neg", (fun _ -> neg_interval);
    "real_inv", inv_interval;
    "sqrt", sqrt_interval;
    "atn", atn_interval;
    "acs", acs_interval;
  ] in
  let table = Hashtbl.create 10 in
    List.iter (fun (name, op) -> Hashtbl.add table name op) operations;
    table;;


(* Table of binary interval operations, keyed by the name of the
   corresponding constant. Each entry takes pp and then the two
   argument intervals. *)
let binary_interval_operations =
  let operations = [
    "real_add", add_interval;
    "real_sub", sub_interval;
    "real_mul", mul_interval;
    "real_div", div_interval;
  ] in
  let table = Hashtbl.create 10 in
    List.iter (fun (name, op) -> Hashtbl.add table name op) operations;
    table;;


(* Table of interval approximations of named constants.
   Each entry is a function of pp; "pi" reads element pp of
   pi_approx_array. *)
let interval_constants =
  let constants = [
    "pi", (fun pp -> pi_approx_array.(pp));
  ] in
  let table = Hashtbl.create 10 in
    List.iter (fun (name, approx) -> Hashtbl.add table name approx) constants;
    table;;



(* Type of an interval function: an expression tree which can be
   evaluated to an interval by eval_interval_fun *)
type interval_fun =
  (* Reference to a previously computed value: a 0-based index into the
     list of references given to eval_interval_fun *)
  | Int_ref of int
  (* Variable: a 1-based index into the list of variable values given
     to eval_interval_fun *)
  | Int_var of int
  (* An interval which is already known *)
  | Int_const of interval
  (* A DECIMAL term; converted with mk_float_interval_decimal when evaluated *)
  | Int_decimal_const of term
  (* A named constant; looked up in interval_constants when evaluated *)
  | Int_named_const of string
  (* Power: the exponent and the base expression (evaluated with pow_interval) *)
  | Int_pow of int * interval_fun
  (* Unary operation: its name (a key of unary_interval_operations)
     and the argument *)
  | Int_unary of string * interval_fun
  (* Binary operation: its name (a key of binary_interval_operations)
     and the two arguments *)
  | Int_binary of string * interval_fun * interval_fun;;


(* Structural equality of two interval functions.
   Two functions are equal when they are built with the same constructor
   and all their components are equal; the bounds of interval constants
   are compared with eq_float. Functions built with different
   constructors are never equal. *)
let rec eq_ifun ifun1 ifun2 =
  match ifun1, ifun2 with
    | Int_const a, Int_const b ->
        let a_lo, a_hi = dest_interval a in
        let b_lo, b_hi = dest_interval b in
          eq_float a_lo b_lo && eq_float a_hi b_hi
    | Int_ref i, Int_ref j -> i = j
    | Int_var i, Int_var j -> i = j
    | Int_decimal_const s, Int_decimal_const t -> s = t
    | Int_named_const a, Int_named_const b -> a = b
    | Int_pow (m, x), Int_pow (n, y) -> m = n && eq_ifun x y
    | Int_unary (p, x), Int_unary (q, y) -> p = q && eq_ifun x y
    | Int_binary (p, x1, x2), Int_binary (q, y1, y2) ->
        p = q && eq_ifun x1 y1 && eq_ifun x2 y2
    | _ -> false;;


(* Evaluates an interval function to an interval.
   pp   - passed to every interval operation and constant approximation
   ifun - the interval function to evaluate
   vars - values of variables; Int_var i reads element (i - 1)
   refs - values of references; Int_ref i reads element i
   Operations and named constants are looked up by name with
   Hashtbl.find, so an unknown name raises Not_found. *)
let eval_interval_fun pp ifun vars refs =
  let rec eval = function
    | Int_ref index -> List.nth refs index
    | Int_var index -> List.nth vars (index - 1)
    | Int_const value -> value
    | Int_decimal_const tm -> mk_float_interval_decimal pp tm
    | Int_named_const name -> (Hashtbl.find interval_constants name) pp
    | Int_pow (n, arg) -> pow_interval pp n (eval arg)
    | Int_unary (op, arg) ->
        (Hashtbl.find unary_interval_operations op) pp (eval arg)
    | Int_binary (op, lhs, rhs) ->
        (Hashtbl.find binary_interval_operations op) pp (eval lhs) (eval rhs) in
    eval ifun;;


(* Pre-computes every sub-expression which depends on constants only.
   Decimal and named constants become Int_const values, and a power,
   unary or binary operation is replaced by an Int_const whenever all of
   its (already simplified) arguments are Int_const. References and
   variables are left untouched. *)
let eval_constants pp ifun =
  let rec simplify = function
    | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
    | Int_named_const name ->
        Int_const (Hashtbl.find interval_constants name pp)
    | Int_pow (n, arg) ->
        (match simplify arg with
           | Int_const value -> Int_const (pow_interval pp n value)
           | arg' -> Int_pow (n, arg'))
    | Int_unary (op, arg) ->
        (match simplify arg with
           | Int_const value ->
               Int_const (Hashtbl.find unary_interval_operations op pp value)
           | arg' -> Int_unary (op, arg'))
    | Int_binary (op, lhs, rhs) ->
        (match simplify lhs, simplify rhs with
           | Int_const a, Int_const b ->
               Int_const (Hashtbl.find binary_interval_operations op pp a b)
           | lhs', rhs' -> Int_binary (op, lhs', rhs'))
    | (Int_ref _ | Int_var _ | Int_const _) as leaf -> leaf in
    simplify ifun;;



(**************************************)

(* Translates a term into an interval function.
   - a constant c becomes Int_named_const with the name of c
   - a bare variable is rejected with a failure
   - `&n` becomes Int_const (the interval of the numeral n)
   - any other constant applied to one argument becomes Int_unary
   - DECIMAL terms become Int_decimal_const (the whole term is stored)
   - `x pow n` becomes Int_pow (n must be a small numeral)
   - `x$i` becomes Int_var i
   - any other constant applied to two arguments becomes Int_binary *)
let build_interval_fun =
  let num_to_real = `(&):num -> real` in
  let const_name = fst o dest_const in
  let rec build tm =
    if is_const tm then
      Int_named_const (const_name tm)
    else if is_var tm then
      failwith ("Variables should be of the form name$i: " ^ string_of_term tm)
    else
      let head, arg = dest_comb tm in
        if is_const head then build_unary head arg
        else build_binary tm head arg
  (* head is a constant applied to the single argument arg *)
  and build_unary head arg =
    if head = num_to_real then
      Int_const (mk_num_interval (dest_numeral arg))
    else
      let arg_fun = build arg in
        Int_unary (const_name head, arg_fun)
  (* tm is the combination of head and arg2, and head is itself expected
     to be a constant applied to a first argument *)
  and build_binary tm head arg2 =
    let op, arg1 = dest_comb head in
      match const_name op with
        | "DECIMAL" -> Int_decimal_const tm
        | "real_pow" ->
            let n = dest_small_numeral arg2 in
              Int_pow (n, build arg1)
        | "$" ->
            (* arg2 is the operand of tm, i.e. the index of the component *)
            Int_var (dest_small_numeral arg2)
        | name ->
            let lhs = build arg1 and
                rhs = build arg2 in
              Int_binary (name, lhs, rhs) in
    build;;


(* Substitutes Int_ref expr_index for every occurrence of expr
   (compared with eq_ifun) in all interval functions of f_list.
   The search goes top-down and does not descend into a replaced
   occurrence.
   Returns the total number of replacements together with the new list
   of interval functions. *)
let replace_subexpr expr expr_index f_list =
  let rec subst f =
    if eq_ifun f expr then
      1, Int_ref expr_index
    else
      match f with
        | Int_pow (k, arg) ->
            let count, arg' = subst arg in
              count, Int_pow (k, arg')
        | Int_unary (op, arg) ->
            let count, arg' = subst arg in
              count, Int_unary (op, arg')
        | Int_binary (op, lhs, rhs) ->
            let lhs_count, lhs' = subst lhs in
            let rhs_count, rhs' = subst rhs in
              lhs_count + rhs_count, Int_binary (op, lhs', rhs')
        | leaf -> 0, leaf in
  let counts, new_list = unzip (map subst f_list) in
    List.fold_left (+) 0 counts, new_list;;


		
(* True for interval functions without sub-expressions, i.e. for
   everything except powers, unary and binary operations *)
let is_leaf = function
  | Int_pow _ | Int_unary _ | Int_binary _ -> false
  | Int_ref _ | Int_var _ | Int_const _
  | Int_decimal_const _ | Int_named_const _ -> true;;

(* Factors out repeated sub-expressions of the interval functions in f_list.
   The functions are processed from left to right. For the current function,
   a non-leaf sub-expression which occurs more than once in the current and
   the remaining functions is appended to the list of references and all its
   occurrences in these functions are replaced by an Int_ref whose index is
   the position of the sub-expression in the list of references. This is
   repeated until no such sub-expression is found; then the next function
   is processed.
   acc is the initial list of references.
   Returns the list of rewritten functions and the final list of references. *)
let find_and_replace_all f_list acc =
  (* Keeps the result of a search inside a child if a repeated expression
     was found there (count > 1); otherwise tries to replace f itself
     by the reference i in all functions of fs *)
  let child_or_self f i fs found =
    let _, (count, _) = found in
      if count > 1 then found else f, replace_subexpr f i fs in

  (* Looks inside f for a sub-expression occurring more than once in fs,
     trying the sub-expressions of f before f itself.
     Returns the last tried expression, the number of its occurrences,
     and fs with these occurrences replaced by the reference i *)
  let rec search f i fs =
    match f with
      | Int_pow (_, arg) | Int_unary (_, arg) ->
          child_or_self f i fs (search arg i fs)
      | Int_binary (_, lhs, rhs) ->
          let lhs_found = search lhs i fs in
          let _, (lhs_count, _) = lhs_found in
          let found = if lhs_count > 1 then lhs_found else search rhs i fs in
            child_or_self f i fs found
      | _ -> f, (0, fs) in

  (* Extracts repeated sub-expressions of the head of fs until none is left *)
  let rec saturate fs refs =
    let index = length refs in
    let expr, (count, fs') = search (hd fs) index fs in
      if count > 1 then saturate fs' (refs @ [expr]) else fs, refs in

  (* Processes the pending functions one by one *)
  let rec process pending refs finished =
    match pending with
      | [] -> finished, refs
      | _ ->
          let fs', refs' = saturate pending refs in
            process (tl fs') refs' (finished @ [hd fs']) in

    process f_list acc [];;


(* Evaluates a list of interval functions which share a list of references.
   The references are evaluated first, in order; each reference may use the
   values of the references which precede it. Then every function of f_list
   is evaluated with the complete list of reference values.
   vars is the list of variable values passed to eval_interval_fun. *)
let eval_interval_fun_list pp (f_list, refs) vars =
  let ref_values =
    List.fold_left
      (fun known r -> known @ [eval_interval_fun pp r vars known])
      [] refs in
    map (fun f -> eval_interval_fun pp f vars ref_values) f_list;;


end;;


(*
(* Tests *)
needs "../formal_lp/formal_interval/eval_interval.hl";;

let pp = 7;;

let var_tm = `&1 / #7.1`;;
let var = eval_interval_fun pp (build_interval_fun var_tm) [] [];;
let var0 = Informal_eval_interval.eval_interval_fun pp 
  (Informal_eval_interval.build_interval_fun var_tm) [] [];;

let test_vars = [`(x:real^1)$1`, var];;
let test_vars0 = [var0];;
let test_expr1 = `(x:real^1)$1 * x$1 + (&3 * x$1) * x$1 * x$1 + &3 * x$1 + &3 * x$1`;;
let test_expr2 = `((x:real^1)$1 * x$1) * (x$1 * &2) + x$1 * &2`;;
let subexpr1 = `(x:real^1)$1 * x$1` and subexpr2 = `&3 * (x:real^1)$1`;;

let test_f1 = build_interval_fun test_expr1 and
    test_f2 = build_interval_fun test_expr2;;
let sub1 = build_interval_fun subexpr1 and sub2 = build_interval_fun subexpr2;;

let f1, f2, s1, s2 =
  let build = Informal_eval_interval.build_interval_fun in
    build test_expr1, build test_expr2, build subexpr1, build subexpr2;;

let v = find_and_replace_all [test_f1; test_f2] [];;
let v0 = Informal_eval_interval.find_and_replace_all [f1; f2] [];;

let test_dest int =
  let lo, hi = Informal_interval.dest_interval int in
    Informal_float.dest_float lo, Informal_float.dest_float hi;;

eval_interval_fun_list pp v test_vars;;
map test_dest (Informal_eval_interval.eval_interval_fun_list pp v0 test_vars0);;

map (fun f -> eval_interval_fun pp f test_vars []) [test_f1; test_f2];;
let r = map (fun f -> Informal_eval_interval.eval_interval_fun pp f test_vars0 []) [f1; f2];;
map test_dest r;;


(* 0.712 *)
test 100 (map (fun f -> eval_interval_fun pp f test_vars [])) [test_f1; test_f2];;
(* 0.432 *)
test 100 (eval_interval_fun_list pp v) test_vars;;

test 100 (map (fun f -> Informal_eval_interval.eval_interval_fun pp f test_vars0 [])) [f1; f2];;
test 100 (Informal_eval_interval.eval_interval_fun_list pp v0) test_vars0;;
*)