needs "../formal_lp/formal_interval/m_taylor_arith.hl";;
needs "../formal_lp/formal_interval/m_verifier.hl";;


(* Precision parameter and dimension used by the first multiplication check. *)
let pp = 10;;
let n = 2;;

(* Two sample expressions in x1 and x2, converted with expr_to_vector_fun. *)
let poly1 = expr_to_vector_fun `x1 * x2`;;
let poly2 = expr_to_vector_fun `x2 + x1 * x2`;;

(* For each expression keep the fourth component of the first result
   (eval1 / eval2) and the second result (tf1 / tf2). *)
let (_, _, _, eval1), tf1 =
  mk_verification_functions pp poly1 false `&0`;;
let (_, _, _, eval2), tf2 =
  mk_verification_functions pp poly2 false `&0`;;


(* The two coordinate lists handed to mk_m_center_domain below. *)
let xx = `[&1; &1]`;;
let zz = `[&2; &2]`;;

(* Float-term versions of both lists (flag true for xx, false for zz). *)
let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;

(* OCaml float lists; 0.0 entries are appended so that a list of at most
   8 coordinates ends up with exactly 8 entries. *)
let xx_float, zz_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1;;

let domain_th = mk_m_center_domain n pp xx1 zz1;;

(* Evaluate atn composed with tf2, the product of the two evaluations on
   domain_th, and the product tf1 * tf2 on the float box. *)
evalf (Uni_compose (uatan, tf2)) xx_float zz_float;;
eval_m_taylor_mul n pp (eval1 domain_th) (eval2 domain_th);;
evalf (Product (tf1, tf2)) xx_float zz_float;;



(******************************)

(* Dummy functions: placeholders that always fail with their own name.
   They fill the first three slots of the evaluator tuples passed to
   m_verify_list / m_verify_raw0 in the tests below. *)

let eval_d0 = fun _i _t1 _t2 -> failwith "eval_d0";;
let eval_dd0 = fun _i _j _t1 _t2 -> failwith "eval_dd0";;
let eval_0 = fun _t1 _t2 -> failwith "eval_0";;


(******************************)


(* 1d test: dimension 1, precision parameter 5 *)

let n = 1;;
let pp = 5;;

(* Evaluator for the coordinate function x$1 on real^1. *)
let eval_x =
  let coord = `\x:real^1. x$1` in
    eval_m_taylor_poly0 pp coord;;

(* Evaluates x * x - atn x on the domain theorem th, using the
   eval_m_taylor_* operations for the current n and pp. *)
let f_eval th =
  let x = eval_x th in
  let sub = eval_m_taylor_sub n pp
  and mul = eval_m_taylor_mul n pp
  and atn = eval_m_taylor_atn n pp in
  (* same evaluation order as the infix form: atn x first, then x * x *)
  let atn_x = atn x in
  let x_sq = mul x x in
    sub x_sq atn_x;;

(* Taylor-function counterpart of f_eval: x1 * x1 + (-1) * atn x1. *)
let tf =
  let x_squared = Product (x1, x1) in
  let minus_atn = Scale (Uni_compose (uatan, x1), ineg one) in
    Plus (x_squared, minus_atn);;


(* One-dimensional domain given by the lists xx and zz. *)
let xx = `[#0.01]`;;
let zz = `[#0.8]`;;

let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;

(* Float lists with 0.0 entries appended (8 entries for short lists). *)
let xx_float, zz_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1;;

let domain_th = mk_m_center_domain n pp xx1 zz1;;
f_eval domain_th;;
evalf tf xx_float zz_float;;

(* Produce a certificate with run_test, transform it, and pass it to
   m_verify_list together with f_eval. *)
let certificate =
  run_test tf xx_float zz_float false 0.0 true false true false 0.0;;
let c1 = transform_result xx_float zz_float certificate;;

m_verify_list n pp (eval_d0, eval_dd0, eval_0, f_eval) c1 xx1 zz1;;


(********************)

(* Rational-approximation test: dimension 1, precision parameter 10. *)
let n = 1;;
let pp = 10;;

(* Numerator term, written with let-bound coefficients P00, P01, P02. *)
let atan_num = `let P00, P01, P02 = #4.26665376811246382, #3.291955407318148, #0.281939359003812 in
  let x1 = x * x in
  let x2 = x1 * x1 in
    x * (x1 * P01 + x2 * P02 + P00)`;;

(* Denominator term, written with let-bound coefficients Q00, Q01. *)
let atan_denom = `let Q00, Q01 = #4.26665376811354027, #4.714173328637605 in
    let x1 = x * x in
    let x2 = x1 * x1 in
      x1 * Q01 + x2 + Q00`;;


(* Apply REPEATC let_CONV to each term, take the operand (rand) of the
   resulting conclusion, and convert it with expr_to_vector_fun. *)
let num_poly, denom_poly =
  let expand_lets tm =
    expr_to_vector_fun (rand (concl (REPEATC let_CONV tm))) in
  let num = expand_lets atan_num in
  let denom = expand_lets atan_denom in
    num, denom;;

(* Evaluators (fourth component) and Taylor functions for the numerator,
   the denominator and the constant #0.00000001. *)
let (_, _, _, eval_num), tf_num =
  mk_verification_functions pp num_poly false `&0`;;
let (_, _, _, eval_denom), tf_denom =
  mk_verification_functions pp denom_poly false `&0`;;

let (_, _, _, eval_eps), tf_eps =
  mk_verification_functions pp `\x:real^1. #0.00000001` false `&0`;;


(* Evaluates (num / denom - atn x) - eps on the domain theorem th, where
   the quotient is formed as num * inv denom.
   NOTE(review): eval_x is the binding created earlier in the file, when
   pp had a different value; confirm this precision mix is intended. *)
let eval_approx_hi th =
  let x = eval_x th in
  let eps = eval_eps th in
  let atn = eval_m_taylor_atn n pp
  and inv = eval_m_taylor_inv n pp
  and mul = eval_m_taylor_mul n pp
  and sub = eval_m_taylor_sub n pp in
  (* intermediate results, computed in the same order as the infix form *)
  let atn_x = atn x in
  let inv_denom = inv (eval_denom th) in
  let ratio = mul (eval_num th) inv_denom in
    sub (sub ratio atn_x) eps;;

(***)
(* One-dimensional domain for the approximation check. *)
let xx = `[#0.0000001]`;;
let zz = `[#0.2]`;;

let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;

(* Float lists with 0.0 entries appended (8 entries for short lists). *)
let xx_float, zz_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1;;

let domain_th = mk_m_center_domain n pp xx1 zz1;;


eval_approx_hi domain_th;;
(* recorded note: 10: 0.800 (p = 15) -- presumably a timing; TODO confirm *)
test 1 eval_approx_hi domain_th;;


(* Taylor-function counterpart of eval_approx_hi:
   num * inv denom + (-1) * atn x1 + (-1) * eps. *)
let tf_approx_hi =
  let ratio = Product (tf_num, Uni_compose (uinv, tf_denom)) in
  let minus_atn = Scale (Uni_compose (uatan, x1), ineg one) in
  let minus_eps = Scale (tf_eps, ineg one) in
    Plus (Plus (ratio, minus_atn), minus_eps);;


let certificate =
  run_test tf_approx_hi xx_float zz_float false 0.0 true false true false 0.0;;
result_stat certificate;;
let c1 = transform_result xx_float zz_float certificate;;

m_verify_list n pp (eval_d0, eval_dd0, eval_0, eval_approx_hi) c1 xx1 zz1;;




(* 2d test: dimension 2, precision parameter 10 *)

let n = 2;;
let pp = 10;;

(* Evaluators for the two coordinate functions on real^2 ... *)
let eval_x = eval_m_taylor_poly0 pp `\x:real^2. x$1`;;
let eval_y = eval_m_taylor_poly0 pp `\x:real^2. x$2`;;

(* ... and for their product taken as a single polynomial. *)
let eval_xy =
  let product = `\x:real^2. x$1 * x$2` in
    eval_m_taylor_poly0 pp product;;

(* Evaluates atn (x * y) - x * y on th, building the product from the two
   coordinate evaluators.  The product is formed twice, exactly as in the
   infix expression; f_eval2 is the variant using a single polynomial. *)
let f_eval th =
  let x = eval_x th in
  let y = eval_y th in
  let sub = eval_m_taylor_sub n pp
  and mul = eval_m_taylor_mul n pp
  and atn = eval_m_taylor_atn n pp in
  let rhs = mul x y in
  let lhs = atn (mul x y) in
    sub lhs rhs;;

(* Evaluates atn xy - xy on the domain theorem th, where xy comes from the
   single-polynomial evaluator eval_xy.  Only subtraction and atn are
   needed, so no multiplication operator is bound here. *)
let f_eval2 th =
  let xy = eval_xy th in
  let ( - ), atn = eval_m_taylor_sub n pp, eval_m_taylor_atn n pp in
    atn xy - xy;;


(* Taylor-function counterpart of f_eval / f_eval2:
   atn (x1 * x2) + (-1) * (x1 * x2). *)
let tf =
  let atn_prod = Uni_compose (uatan, Product (x1, x2)) in
  let minus_prod = Scale (Product (x1, x2), ineg one) in
    Plus (atn_prod, minus_prod);;


(* Two-dimensional domain given by the lists xx and zz. *)
let xx = `[#0.1; #0.1]`;;
let zz = `[#0.6; #0.6]`;;

let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;

(* Float lists with 0.0 entries appended (8 entries for short lists). *)
let xx_float, zz_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1;;

let domain_th = mk_m_center_domain n pp xx1 zz1;;

let certificate =
  run_test tf xx_float zz_float false 0.0 true false true false 0.0;;
let c1 = transform_result xx_float zz_float certificate;;

(* Verification with the product built from two evaluators.
   Recorded note: 10: 33.926 -- presumably a timing; TODO confirm *)
test 1 (m_verify_list n pp (eval_d0, eval_dd0, eval_0, f_eval) c1 xx1) zz1;;
(* Verification with the single-polynomial product.
   Recorded note: 10: 25.089 -- presumably a timing; TODO confirm *)
test 1 (m_verify_list n pp (eval_d0, eval_dd0, eval_0, f_eval2) c1 xx1) zz1;;



(* 6d tests: dimension 6, precision parameter 10 *)

let n = 6;;
let pp = 10;;

(* Evaluator (fourth component) and Taylor function for the polynomial
   x1 * x2 + x3 * x4 + x5 * x6 on real^6. *)
let (_, _, _, eval_xy), tf0 =
  let poly = `\x:real^6. x$1 * x$2 + x$3 * x$4 + x$5 * x$6` in
    mk_verification_functions pp poly false `&0`;;

(* Evaluates xy - sqrt xy on the domain theorem th, where xy is the result
   of eval_xy.  Only subtraction and sqrt are needed, so no multiplication
   operator is bound here. *)
let f_eval_sqrt th =
  let xy = eval_xy th in
  let ( - ), sqrt = eval_m_taylor_sub n pp, eval_m_taylor_sqrt n pp in
    xy - sqrt xy;;

(* Evaluates atn xy - xy on the domain theorem th, where xy is the result
   of eval_xy.  Only subtraction and atn are needed, so no multiplication
   operator is bound here. *)
let f_eval_atn th =
  let xy = eval_xy th in
  let ( - ), atn = eval_m_taylor_sub n pp, eval_m_taylor_atn n pp in
    atn xy - xy;;


(* Taylor-function counterparts of f_eval_sqrt and f_eval_atn:
   tf0 + (-1) * sqrt tf0   and   atn tf0 + (-1) * tf0. *)
let tf_sqrt =
  let minus_sqrt = Scale (Uni_compose (usqrt, tf0), ineg one) in
    Plus (tf0, minus_sqrt);;
let tf_atn =
  let minus_tf0 = Scale (tf0, ineg one) in
    Plus (Uni_compose (uatan, tf0), minus_tf0);;


(* Six-dimensional domain given by the lists xx and zz. *)
let xx = `[#0.1; #0.1; #0.1; #0.1; #0.1; #0.1]`;;
let zz = `[#0.2; #0.2; #0.2; #0.2; #0.2; #0.2]`;;

let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;

(* Float lists with 0.0 entries appended (8 entries for short lists). *)
let xx_float, zz_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1;;


(* Certificate and transformed certificate for the sqrt test ... *)
let certificate_sqrt =
  run_test tf_sqrt xx_float zz_float false 0.0 true false true false 0.0;;
let c1_sqrt = transform_result xx_float zz_float certificate_sqrt;;

(* ... and for the atn test. *)
let certificate_atn =
  run_test tf_atn xx_float zz_float false 0.0 true false true false 0.0;;
let c1_atn = transform_result xx_float zz_float certificate_atn;;

m_verify_list n pp (eval_d0, eval_dd0, eval_0, f_eval_sqrt) c1_sqrt xx1 zz1;;
m_verify_list n pp (eval_d0, eval_dd0, eval_0, f_eval_atn) c1_atn xx1 zz1;;



(*********************************)

(* real tests: dimension 6, precision parameter 15 *)

let n = 6;;
let pp = 15;;
(* alternative setting: let pp = 8;; *)

(* delta_y: right-hand side of Sphere.delta_y after rewriting with
   Sphere.delta_x, converted with expr_to_vector_fun *)
let delta_y_poly =
  let unfolded = REWRITE_RULE [Sphere.delta_x] Sphere.delta_y in
    expr_to_vector_fun (rand (concl (SPEC_ALL unfolded)));;

(* delta_y4: Sphere.delta_x4 specialised to the squares y1*y1, ..., y6*y6 *)
let delta_y4_poly =
  let squares = [`y1*y1`; `y2*y2`; `y3*y3`; `y4*y4`; `y5*y5`; `y6*y6`] in
    expr_to_vector_fun (rand (concl (SPECL squares Sphere.delta_x4)));;

(* &4 * ((x$1 * x$1) * delta_y): the body of delta_y_poly multiplied by
   the square of the first coordinate and by 4 *)
let four_y1_delta_y_poly =
  let x_var, body = dest_abs delta_y_poly in
  let scaled = mk_binop mul_op_real `(x:real^6)$1 * x$1` body in
    mk_abs (x_var, mk_binop mul_op_real `&4` scaled);;

(* - delta_y4: the body of delta_y4_poly under real negation *)
let neg_delta_y4_poly =
  let x_var, body = dest_abs delta_y4_poly in
  let negated = mk_comb (neg_op_real, body) in
    mk_abs (x_var, negated);;


(* Evaluators (fourth component) and Taylor functions for the two
   polynomials and for the two constants used by the dih_y tests. *)
let (_, _, _, eval_neg_delta_y4), tf_neg_delta_y4 =
  mk_verification_functions pp neg_delta_y4_poly false `&0`;;

let (_, _, _, eval_4y1_delta_y), tf_4y1_delta_y =
  mk_verification_functions pp four_y1_delta_y_poly false `&0`;;

(* constant pi / 2 *)
let (_, _, _, eval_pi2), tf_pi2 =
  let half_pi = `\x:real^6. pi / &2` in
    mk_verification_functions pp half_pi false `&0`;;


(* constant pi / 2 - 1.893 *)
let (_, _, _, eval_pi2_plus), tf_pi2_plus =
  let shifted_half_pi = `\x:real^6. pi / &2 - #1.893` in
    mk_verification_functions pp shifted_half_pi false `&0`;;



(* Display the definition theorem for reference. *)
Sphere.dih_x;;

(* dih_y shifted by the constant in tf_pi2_plus:
   tf_pi2_plus + atn (tf_neg_delta_y4 * inv (sqrt tf_4y1_delta_y)) *)
let tf_dih_y_hi =
  let root = Uni_compose (usqrt, tf_4y1_delta_y) in
  let inv_root = Uni_compose (uinv, root) in
  let atn_arg = Product (tf_neg_delta_y4, inv_root) in
    Plus (tf_pi2_plus, Uni_compose (uatan, atn_arg));;

(* Evaluates pi2_plus + atn (neg_delta_y4 * inv (sqrt 4y1_delta_y)) on the
   domain theorem th, mirroring the structure of tf_dih_y_hi. *)
let eval_dih_y_hi th =
  let inv = eval_m_taylor_inv n pp
  and atn = eval_m_taylor_atn n pp
  and sqrt = eval_m_taylor_sqrt n pp
  and mul = eval_m_taylor_mul n pp
  and add = eval_m_taylor_add n pp in
  let under_root = eval_4y1_delta_y th in
  let numerator = eval_neg_delta_y4 th in
  let shift = eval_pi2_plus th in
  let inv_root = inv (sqrt under_root) in
  let angle = atn (mul numerator inv_root) in
    add shift angle;;
    


(************)

(* Full domain: every coordinate between 2 and 2.52. *)
let xx = `[&2; &2; &2; &2; &2; &2]`;;
let zz = `[#2.52; #2.52; #2.52; #2.52; #2.52; #2.52]`;;

let xx1 = convert_to_float_list pp true xx;;
let zz1 = convert_to_float_list pp false zz;;


(* Small domain: every coordinate between 2 and 2.1. *)
let xx_s = `[&2; &2; &2; &2; &2; &2]`;;
let zz_s = `[#2.1; #2.1; #2.1; #2.1; #2.1; #2.1]`;;

(* Domain theorem for the small domain; the float-term lists stay local. *)
let domain_th =
  let lo = convert_to_float_list pp true xx_s in
  let hi = convert_to_float_list pp false zz_s in
    mk_m_center_domain n pp lo hi;;

(* recorded note: 10: 9.121 (pp = 15); 300: 5.5 (pp = 8)
   -- presumably timings; TODO confirm *)
test 1 eval_dih_y_hi domain_th;;


(***)

(* Mixed domain: the first coordinate runs up to 2.52, the others to 2.1. *)
let xx_s = `[&2; &2; &2; &2; &2; &2]`;;
let zz_s = `[#2.52; #2.1; #2.1; #2.1; #2.1; #2.1]`;;

let xx1_s = convert_to_float_list pp true xx_s;;
let zz1_s = convert_to_float_list pp false zz_s;;


(* Float lists for both the full and the mixed domain; the number of 0.0
   entries appended is derived from the length of xx1. *)
let xx_float, zz_float, xx_s_float, zz_s_float =
  let padding = replicate 0.0 (8 - length (dest_list xx1)) in
  let as_floats tm = map float_of_float_tm (dest_list tm) @ padding in
    as_floats xx1, as_floats zz1, as_floats xx1_s, as_floats zz1_s;;



(***)

(* Certificates for tf_dih_y_hi: one on the mixed domain (c_dih_y_s) and
   four on the full domain, which differ only in the option arguments
   passed to run_test. *)
let c_dih_y_s = run_test tf_dih_y_hi xx_s_float zz_s_float false 0.0 true false true false 0.0;;
let c_dih_y = run_test tf_dih_y_hi xx_float zz_float false 0.0 true false true false 0.0;;
let c_dih_y0 = run_test tf_dih_y_hi xx_float zz_float false 0.0 true false false false 0.0;;
let c_dih_y_full = run_test tf_dih_y_hi xx_float zz_float false 0.0 true true true true 0.0;;
let c_dih_y_min = run_test tf_dih_y_hi xx_float zz_float false 0.0 false false false false 0.0;;

(* Recorded statistics for each certificate. *)
(* pass = 4 *)
result_stat c_dih_y_s;;

(* pass = 4292, mono = 8, pass_mono = 2 *)
result_stat c_dih_y;;
(* pass = 4294, mono = 10 *)
result_stat c_dih_y0;;
(* pass = 4292, mono = 8, pass_mono = 2 *)
result_stat c_dih_y_full;;
(* pass = 4317, mono = 0 *)
result_stat c_dih_y_min;;

(* The certificates above were produced from tf_dih_y_hi, so they are
   checked with the matching evaluator eval_dih_y_hi (no evaluator named
   eval_dih_y is defined in this file). *)
m_verify_raw0 n pp (eval_d0, eval_dd0, eval_0, eval_dih_y_hi) c_dih_y_s xx1_s zz1_s;;
(* recorded note: 10: 38.418; 300: 22.289 (pp = 8)
   -- presumably timings; TODO confirm *)
test 1 (m_verify_raw0 n pp (eval_d0, eval_dd0, eval_0, eval_dih_y_hi) c_dih_y_s xx1_s) zz1_s;;



m_verify_raw0 n pp (eval_d0, eval_dd0, eval_0, eval_dih_y_hi) c_dih_y0 xx1 zz1;;
let c1_dih_y = transform_result xx_float zz_float c_dih_y;;



(*************************************)
(* dih_y: tf_pi2 + atn (tf_neg_delta_y4 * inv (sqrt tf_4y1_delta_y)) *)
let tf_dih_y =
  let root = Uni_compose (usqrt, tf_4y1_delta_y) in
  let inv_root = Uni_compose (uinv, root) in
  let atn_arg = Product (tf_neg_delta_y4, inv_root) in
    Plus (tf_pi2, Uni_compose (uatan, atn_arg));;


(* sol_y: sum of tf_dih_y and two copies of it with permuted arguments,
   plus the constant --pi *)

let _, tf_neg_pi =
  mk_verification_functions 15 `\x:real^6. --pi` false `&0`;;

Sphere.sol_y;;
let tf_sol_y =
  let permuted_a = Composite (tf_dih_y, x2, x3, x1, x5, x6, x4, unit, unit) in
  let permuted_b = Composite (tf_dih_y, x3, x1, x2, x6, x4, x5, unit, unit) in
  let tail = Plus (permuted_a, Plus (permuted_b, tf_neg_pi)) in
    Plus (tf_dih_y, tail);;

(* taum *)

(* Floating-point constants: const1_val = sol0_val / pi_val, and
   const1_plus1 = const1_val + 1. *)
let pi_val = 3.1415926535897932;;
let sol0_val = 0.5512855984325308;;
let const1_val = sol0_val /. pi_val;;
let const1_plus1 = const1_val +. 1.0;;

(* tf_ly is the Taylor function of 2.52 / 0.52 - x1 * inv 0.52;
   tf_lnazim_y is its product with tf_dih_y. *)
let _, tf_ly =
  let ly_term = `\x:real^1. #2.52 / #0.52 - x$1 * inv(#0.52)` in
    mk_verification_functions 15 ly_term false `&0`;;
let tf_lnazim_y = Product (tf_ly, tf_dih_y);;

Sphere.taum;;
(* const1_plus1 * tf_sol_y - const1_val * (sum of tf_lnazim_y and two
   copies of it with permuted arguments) *)
let tf_taum =
  let point v = mk_interval (v, v) in
  let permuted_a = Composite (tf_lnazim_y, x2, x3, x1, x5, x6, x4, unit, unit) in
  let permuted_b = Composite (tf_lnazim_y, x3, x1, x2, x6, x4, x5, unit, unit) in
  let sol_part = Scale (tf_sol_y, point const1_plus1) in
  let azim_sum = Plus (tf_lnazim_y, Plus (permuted_a, permuted_b)) in
    Plus (sol_part, Scale (azim_sum, point (~-.const1_val)));;



(* dih_y < 1.893, encoded as dih_y - 1.893 *)
let tf_test1 =
  let bound = mk_interval (1.893, 1.893) in
    Plus (tf_dih_y, Scale (unit, ineg bound));;


(* dih_y > 0.852, encoded as 0.852 - dih_y *)
let tf_test2 =
  let bound = mk_interval (0.852, 0.852) in
    Plus (Scale (unit, bound), Scale (tf_dih_y, ineg one));;


(* taum + 0.626 * dih_y - 0.77 > 0, encoded as 0.77 - (taum + 0.626 * dih_y) *)
let tf_test3 =
  let coeff = mk_interval (0.626, 0.626) in
  let const = mk_interval (0.77, 0.77) in
  let lhs = Plus (tf_taum, Scale (tf_dih_y, coeff)) in
    Plus (Scale (unit, const), Scale (lhs, ineg one));;


(* taum - 0.259 * dih_y + 0.32 > 0, encoded as -0.32 - (taum - 0.259 * dih_y) *)
let tf_test4 =
  let coeff = mk_interval (-0.259, -0.259) in
  let const = mk_interval (-0.32, -0.32) in
  let lhs = Plus (tf_taum, Scale (tf_dih_y, coeff)) in
    Plus (Scale (unit, const), Scale (lhs, ineg one));;


(* taum - 0.507 * dih_y + 0.724 > 0, encoded as -0.724 - (taum - 0.507 * dih_y) *)
let tf_test5 =
  let coeff = mk_interval (-0.507, -0.507) in
  let const = mk_interval (-0.724, -0.724) in
  let lhs = Plus (tf_taum, Scale (tf_dih_y, coeff)) in
    Plus (Scale (unit, const), Scale (lhs, ineg one));;


(* sol_y - 0.55125 - 0.196 * (y4 + y5 + y6 - 6) + 0.38 * (y1 + y2 + y3 - 6) > 0,
   encoded as tf_poly - sol_y with tf_poly holding the polynomial part *)
let _, tf_poly =
  let poly =
    `\x:real^6. #0.55125 + #0.196 * (x$4 + x$5 + x$6 - &6) - #0.38 * (x$1 + x$2 + x$3 - &6)` in
    mk_verification_functions 15 poly false `&0`;;

let tf_test6 =
  let minus_sol = Scale (tf_sol_y, ineg one) in
    Plus (tf_poly, minus_sol);;


(* dih_y - 1.2308 + 0.3639 * (y2 + y3 + y5 + y6 - 8) - 0.235 * (y1 - 2) - 0.685 * (y4 - 2) > 0,
   encoded as tf_poly - dih_y; tf_poly is rebound here to the new polynomial *)
let _, tf_poly =
  let poly =
    `\x:real^6. #1.2308 - #0.3639 * (x$2 + x$3 + x$5 + x$6 - &8) + #0.235 * (x$1 - &2) + #0.685 * (x$4 - &2)` in
    mk_verification_functions 15 poly false `&0`;;

let tf_test7 =
  let minus_dih = Scale (tf_dih_y, ineg one) in
    Plus (tf_poly, minus_dih);;



  
(****)
(* tf_test1 on the full domain with two option settings. *)
let c1_1, c1 =
  let run a b c d =
    run_test tf_test1 xx_float zz_float false 0.0 a b c d 0.0 in
  let with_options = run true false true false in
  let plain = run false false false false in
    with_options, plain;;
result_stat c1_1;;
(* pass = 4317 *)
result_stat c1;;


(* tf_test2 on the full domain with the same two option settings. *)
let c2_1, c2 =
  let run a b c d =
    run_test tf_test2 xx_float zz_float false 0.0 a b c d 0.0 in
  let with_options = run true false true false in
  let plain = run false false false false in
    with_options, plain;;
(* pass = 8891, mono = 77, pass_mono = 2 *)
result_stat c2_1;;
(* pass = 9007 *)
result_stat c2;;

(***)
(* tf_test3, tf_test4 and tf_test5 on the full domain, all option flags false. *)
let c3, c4, c5 =
  let run tf =
    run_test tf xx_float zz_float false 0.0 false false false false 0.0 in
  let r3 = run tf_test3 in
  let r4 = run tf_test4 in
  let r5 = run tf_test5 in
    r3, r4, r5;;


(* recorded: 48104 *)
result_stat c3;;
(* recorded: 34015 *)
result_stat c4;;
(* recorded: 25859 *)
result_stat c5;;




(****)
(* tf_test6 on the full domain with two option settings. *)
let c6_1 = run_test tf_test6 xx_float zz_float false 0.0 true false true false 0.0;;
let c6 = run_test tf_test6 xx_float zz_float false 0.0 false false false false 0.0;;
(* pass = 67337, mono = 214, pass_mono = 33 *)
result_stat c6_1;;
(* Statistics for c6 itself; this previously printed c2 again, and no
   pass count has been recorded for c6. *)
result_stat c6;;


(****)
(* tf_test7 on the full domain with three option settings. *)
let c7_1, c7, c7_full =
  let run a b c d =
    run_test tf_test7 xx_float zz_float false 0.0 a b c d 0.0 in
  let with_options = run true false true false in
  let plain = run false false false false in
  let all_options = run true true true true in
    with_options, plain, all_options;;
(* pass = 21104, mono = 1203, pass_mono = 58 *)
result_stat c7_full;;
(* pass = 21104, mono = 1203, pass_mono = 58 *)
result_stat c7_1;;
(* pass = 23251 *)
result_stat c7;;



(****************)

(* Collects the (xx, zz) pairs stored in every Result_pass node of a
   result tree.  Result_glue nodes contribute the pairs of both subtrees,
   left subtree first; every other kind of node contributes nothing. *)
let rec get_pass_intervals = function
  | Result_glue (_, _, left, right) ->
      get_pass_intervals left @ get_pass_intervals right
  | Result_pass (_, xx, zz) -> [(xx, zz)]
  | _ -> [];;

(* Total of result_size over all results in rs. *)
let sum_rs rs = itlist (fun r total -> result_size r + total) rs 0;;
(* Number of results in rs whose result_size equals 1 (the same value as
   summing those sizes). *)
let sum1_rs rs = length (filter (fun r -> result_size r = 1) rs);;


(* Boxes from the Result_pass nodes of each certificate. *)
let c1_ints = get_pass_intervals c1;;
let c2_ints = get_pass_intervals c2;;
let c3_ints = get_pass_intervals c3;;
let c4_ints = get_pass_intervals c4;;
let c5_ints = get_pass_intervals c5;;
let c6_ints = get_pass_intervals c6;;
let c7_ints = get_pass_intervals c7;;

(* Re-run tf_test2 on every box that passed test 1 and on every box that
   passed test 7. *)
let c2_1_rs, c2_7_rs =
  let rerun (lo, hi) =
    run_test tf_test2 lo hi false 0.0 false false false false 0.0 in
  let on_c1 = map rerun c1_ints in
  let on_c7 = map rerun c7_ints in
    on_c1, on_c7;;

(* recorded: 9121 vs 9007 (9007 is the pass count recorded for c2) *)
sum_rs c2_1_rs;;
(* recorded: 2146 *)
sum1_rs c2_1_rs;;

(* recorded: 23922 *)
sum_rs c2_7_rs;;
(* recorded: 22762 *)
sum1_rs c2_7_rs;;

(***)
(* Re-run tf_test1 on every box that passed test 2 and on every box that
   passed test 7. *)
let c1_2_rs, c1_7_rs =
  let rerun (lo, hi) =
    run_test tf_test1 lo hi false 0.0 false false false false 0.0 in
  let on_c2 = map rerun c2_ints in
  let on_c7 = map rerun c7_ints in
    on_c2, on_c7;;


(* recorded: 9121 vs 9007
   NOTE(review): identical to the figures recorded for c2_1_rs -- possibly
   a copied comment; confirm *)
sum_rs c1_2_rs;;
(* recorded: 8933 *)
sum1_rs c1_2_rs;;

(* recorded: 23268 *)
sum_rs c1_7_rs;;
(* recorded: 23247 *)
sum1_rs c1_7_rs;;


(* Re-run tf_test7 on every box that passed test 2. *)
let c7_2_rs =
  let rerun (lo, hi) =
    run_test tf_test7 lo hi false 0.0 false false false false 0.0 in
    map rerun c2_ints;;


(* recorded: 23922 vs 23251 (23251 is the pass count recorded for c7) *)
sum_rs c7_2_rs;;
(* recorded: 3918 *)
sum1_rs c7_2_rs;;


(*********)
(* Re-run tf_test4 and tf_test5 on every box that passed test 3. *)
let c4_3_rs, c5_3_rs =
  let rerun tf (lo, hi) =
    run_test tf lo hi false 0.0 false false false false 0.0 in
  let with_test4 = map (rerun tf_test4) c3_ints in
  let with_test5 = map (rerun tf_test5) c3_ints in
    with_test4, with_test5;;


(* recorded: 54465 vs 34015 (34015 is the figure recorded for c4) *)
sum_rs c4_3_rs;;
(* recorded: 43517 *)
sum1_rs c4_3_rs;;


(* recorded: 52093 vs 25859 (25859 is the figure recorded for c5) *)
sum_rs c5_3_rs;;
(* recorded: 45488 *)
sum1_rs c5_3_rs;;