(* Dependencies *)
needs "../formal_lp/arith/informal/informal_m_taylor.hl";;
needs "../formal_lp/formal_interval/interval_m/verifier.hl";;

module Informal_verifier = struct

open Informal_float;;
open Informal_interval;;
open Informal_taylor;;
open Recurse;;


(* Bundle of evaluation functions for one target function, as consumed by
   m_verify_raw and produced by mk_verification_functions. *)
type verification_funs =
{
  (* p_lin -> p_second -> dom -> ti
     Taylor interval over the cell domain dom. m_verify_raw passes the same
     precision for both integer arguments. *)
  taylor : int -> int -> m_cell_domain -> m_taylor_interval;
  (* pp -> xx -> zz -> interval
     Interval evaluation of the function itself; m_verify_raw calls it with
     xx = domain.lo and zz = domain.hi. *)
  f : int -> ifloat list -> ifloat list -> interval;
  (* j -> pp -> xx -> zz -> interval
     Interval evaluation of the j-th first partial derivative; j is 1-based
     (mk_verification_functions looks it up at list position j - 1). *)
  df : int -> int -> ifloat list -> ifloat list -> interval;
  (* i j -> pp -> xx -> zz -> interval
     Interval evaluation of a second partial derivative; indices are 1-based.
     As built by mk_verification_functions the table is triangular, so the
     lookup only succeeds when i <= j. *)
  ddf : int -> int -> int -> ifloat list -> ifloat list -> interval;
};;


(* m_subset_interval a b c d
   Component-wise comparison of four lists of floats. The result is true
   when le_float a_i c_i holds for every pair taken from a and c, and
   le_float d_i b_i holds for every pair taken from d and b
   (presumably: the box with corners c, d lies inside the box with
   corners a, b). *)
let m_subset_interval a b c d =
  let all_le xs ys =
    itlist2 (fun x y rest -> le_float x y && rest) xs ys true in
    if all_le a c then all_le d b else false;;

(* m_taylor_cell_pass pp ti
   A cell passes when lt0_float accepts the bound that
   eval_m_taylor_upper_bound computes from the Taylor interval ti with
   precision parameter pp. *)
let m_taylor_cell_pass pp ti =
  lt0_float (eval_m_taylor_upper_bound pp ti);;

(* m_taylor_cell_pass0 int
   Pass test on a plain interval: applies lt0_float to the second
   component returned by dest_interval (presumably the upper endpoint). *)
let m_taylor_cell_pass0 int =
  let _, upper = dest_interval int in
    lt0_float upper;;

(* m_cell_pass_subdomain domain2 pass_domain
   Checks, with m_subset_interval, the corners lo/hi of domain2 against the
   corners lo/hi of pass_domain (pass_domain plays the role of the outer
   box). *)
let m_cell_pass_subdomain domain2 pass_domain =
  m_subset_interval pass_domain.lo pass_domain.hi domain2.lo domain2.hi;;

(* m_incr_pass pp j ti
   Applies ge0_float to the bound that eval_m_taylor_partial_lower gives
   for the j-th partial of the Taylor interval ti. *)
let m_incr_pass pp j ti =
  ge0_float (eval_m_taylor_partial_lower pp j ti);;

(* m_decr_pass pp j ti
   Applies le0_float to the bound that eval_m_taylor_partial_upper gives
   for the j-th partial of the Taylor interval ti. *)
let m_decr_pass pp j ti =
  le0_float (eval_m_taylor_partial_upper pp j ti);;

(* m_mono_pass_gen decr_flag bound
   Sign test on an already computed bound: le0_float when decr_flag is
   set, ge0_float otherwise. *)
let m_mono_pass_gen decr_flag bound =
  if decr_flag then le0_float bound else ge0_float bound;;

(* m_convex_pass int
   Applies ge0_float to the first component returned by dest_interval
   (presumably the lower endpoint). *)
let m_convex_pass int =
  let lower, _ = dest_interval int in
    ge0_float lower;;


(* mk_verification_functions pp0 f partials partials2
   Builds the verification_funs record for f.
     pp0       - precision parameter handed to eval_m_taylor and
                 mk_eval_function
     partials  - list with one entry per variable; the right operand (rand)
                 of the conclusion (concl) of each entry is turned into an
                 evaluation function
     partials2 - list of rows; from row i (1-based) only the first i
                 entries are used, giving a triangular table
   In the result, df i selects entry i of the first table and ddf i j
   selects entry i of row j of the triangular table (all 1-based), so ddf
   can only be looked up with i <= j. *)
let mk_verification_functions pp0 f partials partials2 =
  let n = length partials in
  let taylor_fn = eval_m_taylor pp0 f partials partials2 in
  let eval0 = mk_eval_function pp0 f in
  (* evaluation function for the right operand of a conclusion *)
  let eval_of_thm th = mk_eval_function pp0 (rand (concl th)) in
  let eval1 =
    map (fun i -> eval_of_thm (List.nth partials (i - 1))) (1--n) in
  let eval2 =
    map (fun i ->
           let row = List.nth partials2 (i - 1) in
             map (fun j -> eval_of_thm (List.nth row (j - 1))) (1--i))
      (1--n) in
  let first_partial i = List.nth eval1 (i - 1) in
  let second_partial i j = List.nth (List.nth eval2 (j - 1)) (i - 1) in
    {
      taylor = taylor_fn;
      f = eval0;
      df = first_partial;
      ddf = second_partial;
    };;


(* split_domain pp j domain
   Cuts the domain in two along coordinate j (1-based) at t, the j-th
   entry of domain.y. Returns a pair of domains, each rebuilt with
   mk_m_center_domain pp: the first keeps domain.lo and takes domain.hi
   with entry j replaced by t; the second takes domain.lo with entry j
   replaced by t and keeps domain.hi. *)
let split_domain pp j domain =
  let n = length domain.w in
  let t = List.nth domain.y (j - 1) in
  (* copy of corner (first n entries) with entry j replaced by t *)
  let with_cut corner =
    map (fun i -> if i = j then t else List.nth corner (i - 1)) (1--n) in
  let vv = with_cut domain.hi in
  let uu = with_cut domain.lo in
    mk_m_center_domain pp domain.lo vv, mk_m_center_domain pp uu domain.hi;;
  

(* restrict_domain j left_flag domain
   Collapses coordinate j (1-based) of the domain onto one of its faces.
   With left_flag the face value t is the j-th entry of domain.lo and the
   j-th entry of hi is replaced by it; otherwise t is the j-th entry of
   domain.hi and the j-th entry of lo is replaced. In both cases entry j
   of w becomes float_0 and entry j of y becomes t. *)
let restrict_domain j left_flag domain =
  (* copy of xs with entry j replaced by v *)
  let set_at xs v =
    map (fun i -> if i = j then v else List.nth xs (i - 1)) (1--length xs) in
  let t = List.nth (if left_flag then domain.lo else domain.hi) (j - 1) in
  let new_lo, new_hi =
    if left_flag then (domain.lo, set_at domain.hi t)
    else (set_at domain.lo t, domain.hi) in
  let new_w = set_at domain.w float_0 in
  let new_y = set_at domain.y t in
    {lo = new_lo; hi = new_hi; w = new_w; y = new_y};;


(*****************************)
(* When true, m_verify_raw prints progress messages through report
   (chosen precision, monotonicity steps, pass counter, references). *)
let verbatim = ref true;;

(* m_verify_raw p_split p_min p_max fs certificate domain0 ref_list *)
(* Constructs a p_result_tree from the given result_tree.
   The certificate is walked recursively over the cell domain domain0; for
   every mono, pass and glue node the precisions p_min -- p_max are tried in
   list order and the first one for which the node test succeeds is stored
   in the output node.
     p_split  - precision passed to split_domain when a non-convex glue node
                splits the domain
     fs       - verification_funs record (taylor / f / df / ddf)
     ref_list - list of domains consulted by Result_pass_ref nodes
   Fails with "mono: failed", "pass: failed", "glue: failed", "ref: failed"
   or "False result" when a node cannot be verified. Progress messages are
   printed with report while !verbatim is true. *)
let m_verify_raw p_split p_min p_max fs certificate domain0 ref_list =
  (* r_size is only used as the total in the progress message *)
  let r_size = result_size certificate in
  (* k counts the Result_pass nodes visited so far *)
  let k = ref 0 in
  let ps = p_min -- p_max in

  (* finds an optimal precision value *)
  (* Returns the first p in p_list with p_fun p = true; a Failure raised by
     p_fun counts as false. Fails when the list is exhausted. *)
  let rec find_p p_fun p_list =
    match p_list with
      | [] -> failwith "find_p: no good p found"
      | p :: ps -> 
	  let flag = (try p_fun p with Failure _ -> false) in
	    if flag then 
	      let _ = if !verbatim then report (sprintf "p = %d" p) else () in p
	    else find_p p_fun ps in

  (* pass_test *)
  (* f0_flag selects the direct interval evaluation of the function over
     the corners of the domain; otherwise the Taylor interval is used. *)
  let pass_test domain f0_flag pp =
    if f0_flag then
      m_taylor_cell_pass0 (fs.f pp domain.lo domain.hi)
    else
      m_taylor_cell_pass pp (fs.taylor pp pp domain) in

  (* glue_test *)
  (* Only convex glue nodes need a check: the second partial with both
     indices i + 1 must satisfy m_convex_pass. Other glue nodes pass. *)
  let glue_test domain i convex_flag pp =
    if convex_flag then
      m_convex_pass (fs.ddf (i + 1) (i + 1) pp domain.lo domain.hi)
    else
      true in

  (* mono_test *)
  (* Checks every monotonicity claim in mono. NOTE(review): the domains
     argument is not used; every bound is computed on domain itself. *)
  let mono_test mono domain domains pp =
    let xx, zz = domain.lo, domain.hi in
    let taylor = fs.taylor pp pp domain in
    (* bound whose sign is tested: an endpoint of the direct df interval
       when df0_flag is set, a Taylor partial bound otherwise *)
    let gen_mono m =
      if m.df0_flag then
	if m.decr_flag then
	  (snd o dest_interval) (fs.df m.variable pp xx zz)
	else
	  (fst o dest_interval) (fs.df m.variable pp xx zz)
      else
	if m.decr_flag then
	  eval_m_taylor_partial_upper pp m.variable taylor
	else
	  eval_m_taylor_partial_lower pp m.variable taylor in
    let monos = map gen_mono mono in
      rev_itlist (fun (m, bound) pass -> 
		    let flag = m.decr_flag in
		      m_mono_pass_gen flag bound && pass) (rev (zip mono monos)) true in

  (* mk_domains *)
  (* Successively restricts dom0 along each mono variable (decr_flag is
     used as the left_flag of restrict_domain) and returns all intermediate
     domains in order. *)
  let rec mk_domains mono dom0 acc =
    match mono with
      | [] -> rev acc
      | m :: ms ->
	  let j, flag = m.variable, m.decr_flag in
	  let dom = restrict_domain j flag dom0 in
	    mk_domains ms dom (dom :: acc) in

  (* rec_verify *)
  (* Child nodes are verified before the test of the current node, and
     their failures propagate unchanged (they sit outside the try). *)
  let rec rec_verify domain certificate =
    match certificate with
      | Result_mono (mono, r1) ->
	  let _ = 
	    if !verbatim then
	      let mono_strs = 
		map (fun m -> sprintf "%s%d (%b)" (if m.decr_flag then "-" else "") 
		       m.variable m.df0_flag) mono in
		report (sprintf "Mono: [%s]" (String.concat ";" mono_strs))
	    else () in
	  let domains = mk_domains mono domain [] in
	  (* the subtree is verified on the fully restricted domain *)
	  let tree1 = rec_verify (last domains) r1 in
	    (try
	       let pp = find_p (mono_test mono domain domains) ps in
		 P_result_mono ({pp = pp}, mono, tree1)
	     with Failure _ -> failwith "mono: failed")

      | Result_pass (f0_flag, _, _) -> 
	  let _ = k := !k + 1 in
	  let _ = 
	    if !verbatim then
	      report (sprintf "Verifying: %d/%d (f0_flag = %b)" !k r_size f0_flag)
	    else () in
	    (try
	       let pp = find_p (pass_test domain f0_flag) ps in
		 P_result_pass ({pp = pp}, f0_flag)
	     with Failure _ -> failwith "pass: failed")

      | Result_glue (i, convex_flag, r1, r2) ->
	  (* i is 0-based here; the domain helpers take 1-based indices.
	     Convex glue uses the two faces of coordinate i + 1, otherwise
	     the domain is split in two halves. *)
	  let domain1, domain2 =
	    if convex_flag then
	      let d1 = restrict_domain (i + 1) true domain in
	      let d2 = restrict_domain (i + 1) false domain in
		d1, d2
	    else
	      split_domain p_split (i + 1) domain in
	  let tree1 = rec_verify domain1 r1 in
	  let tree2 = rec_verify domain2 r2 in
	    (try
	       let pp = find_p (glue_test domain i convex_flag) ps in
		 P_result_glue ({pp = pp}, i, convex_flag, tree1, tree2)
	     with Failure _ -> failwith "glue: failed")

      | Result_pass_ref i ->
	  let _ = if !verbatim then report (sprintf "Ref: %d" i) else () in
	  (* i > 0: only checks that entry i - 1 of ref_list exists (List.nth
	     raises otherwise). i <= 0: the current domain is checked against
	     entry -i - 1 with m_cell_pass_subdomain.
	     NOTE(review): i = 0 yields index -1, for which List.nth raises
	     Invalid_argument rather than Failure. *)
	  let pass_flag =
	    if i > 0 then
	      let _ = List.nth ref_list (i - 1) in
		true
	    else
	      let pass_domain = List.nth ref_list (-i - 1) in
		m_cell_pass_subdomain domain pass_domain in
	    if not pass_flag then
	      failwith "ref: failed"
	    else
	      P_result_ref i

      | _ -> failwith "False result" in
    
    rec_verify domain0 certificate;;



(*****************)

(* m_verify_raw0 p_split p_min p_max fs certificate xx zz
   Convenience entry point: builds the initial domain from the corners
   xx, zz with mk_m_center_domain p_split and runs m_verify_raw on it with
   an empty reference list. *)
let m_verify_raw0 p_split p_min p_max fs certificate xx zz =
  let domain0 = mk_m_center_domain p_split xx zz in
    m_verify_raw p_split p_min p_max fs certificate domain0 [];;

	    
(* m_verify_list p_split p_min p_max fs certificate_list xx zz
   Verifies a list of (path, certificate) pairs over sub-domains of the
   initial domain built from the corners xx, zz.
     path - list of steps (s, j) applied in order to the initial domain:
            "l" / "r" take the first / second half returned by
            split_domain at coordinate j; any other tag restricts
            coordinate j with restrict_domain, where "ml" means
            left_flag = true
   Each certificate is checked with m_verify_raw; the domains of all
   previously processed entries, in order, are passed as its reference
   list. Returns the list of (path, tree) pairs in input order.
   A progress line is reported for every entry. *)
let m_verify_list p_split p_min p_max fs certificate_list xx zz =
  (* Cache of derived domains, keyed by the concatenated path prefix. *)
  let domain_hash = Hashtbl.create (length certificate_list * 10) in
  let cached key = Hashtbl.mem domain_hash key in
  let lookup key = Hashtbl.find domain_hash key in
  let store key dom = Hashtbl.add domain_hash key dom in

  let get_m_cell_domain pp domain0 path =
    let rec get_rec domain path key =
      match path with
	| [] -> domain
	| (s, j) :: rest ->
	    let next_key = key ^ s ^ (string_of_int j) in
	      if cached next_key then
		get_rec (lookup next_key) rest next_key
	      else
		(* Dispatch on the step tag with a pattern match; this
		   replaces the deprecated boolean operator used before. *)
		(match s with
		   | "l" | "r" ->
		       let domain1, domain2 = split_domain pp j domain in
		       (* both halves are cached, whichever one is needed *)
		       let () = store (key ^ "l" ^ (string_of_int j)) domain1 in
		       let () = store (key ^ "r" ^ (string_of_int j)) domain2 in
			 if s = "l" then
			   get_rec domain1 rest next_key
			 else
			   get_rec domain2 rest next_key
		   | _ ->
		       let l_flag = (s = "ml") in
		       let restricted = restrict_domain j l_flag domain in
		       let () = store next_key restricted in
			 get_rec restricted rest next_key) in
      get_rec domain0 path "" in

  let domain0 = mk_m_center_domain p_split xx zz in
  let size = length certificate_list in
  let k = ref 0 in
  let rec rec_verify certificate_list dom_list tree_list =
    match certificate_list with
      | [] -> rev tree_list
      | (path, certificate) :: cs ->
	  let _ = k := !k + 1; report (sprintf "List: %d/%d" !k size) in
	  let domain = get_m_cell_domain p_split domain0 path in
	  let tree = m_verify_raw p_split p_min p_max fs certificate domain dom_list in
	    (* dom_list must stay in processing order: m_verify_raw indexes
	       it by position *)
	    rec_verify cs (dom_list @ [domain]) ((path, tree) :: tree_list) in
    rec_verify certificate_list [] [];;

end;;