needs "../formal_lp/hypermap/contravening_ineqs.hl";;
needs "../formal_lp/hypermap/list_hypermap_computations.hl";;
needs "../formal_lp/more_arith/prove_lp.hl";;

open Linear_function;;
open Prove_lp;;
open Arith_nat;;
open Misc_vars;;

(* REAL_ADD_ASSOC with its outer quantifiers stripped (SPEC_ALL) and the two
   sides of the resulting equation swapped (SYM). *)
let REAL_ADD_ASSOC' = SYM (SPEC_ALL REAL_ADD_ASSOC);;

(* Moves the parentheses of a left-nested real sum to the right:
     |- (a + ... + c) + d = a + ... + c + d
   The input is returned unchanged (as |- tm = tm) when it is not a real
   addition or when its left operand is not a real addition.  Only the
   second summand of each left operand is processed further; the first
   summand is kept as it is. *)
let rec plus_assoc_conv tm =
  if not (is_binop add_op_real tm) then REFL tm
  else
    let left_tm, right_tm = dest_binop add_op_real tm in
    if not (is_binop add_op_real left_tm) then REFL tm
    else
      let a_tm, b_tm = dest_binop add_op_real left_tm in
      (* One re-association step: instantiate REAL_ADD_ASSOC' with the three
         summands a, b and the right operand. *)
      let assoc_th =
        INST [a_tm, x_var_real; b_tm, y_var_real; right_tm, z_var_real]
          REAL_ADD_ASSOC' in
      (* Split the new right-hand side into its head "(+) a" and its tail. *)
      let head_tm, tail_tm = dest_comb (rand (concl assoc_th)) in
      (* The tail may still be left-nested, so convert it recursively. *)
      let tail_th = plus_assoc_conv tail_tm in
      TRANS assoc_th (AP_TERM head_tm tail_th);;

(*
let tm = `(&1 + x + y + &2) + (z + t)`;;
plus_assoc_conv tm;;
(* 0.252 *)
test 1000 (REWRITE_CONV[GSYM REAL_ADD_ASSOC]) tm;;
(* 0.036 *)
test 1000 plus_assoc_conv tm;;
*)




(* Combines the inequalities of a hypermap linear program into one theorem.

   Arguments:
     hyp_str            - passed unchanged to compute_all, which yields the
                          list hypermap term, a table of named theorems and a
                          table of function-value tables
     precision          - integer precision; selects the inequalities through
                          find_ineq and determines the scaling factor
                          10^precision
     constraints,
     target_var_bounds,
     var_bounds         - lists of triples (name, indices, c): the name of a
                          base inequality, the indices of the instances to
                          select, and the coefficients paired (by zip) with
                          the selected instances

   Result: the theorem obtained by chaining the combined inequality r8 with
   the numeric fact `rhs(r8) <= &12` through REAL_LE_TRANS.

   Failures: raises Failure "Problem: <name>..." when an inequality cannot be
   found, instantiated or rewritten; the message of the underlying Failure
   (if any) is appended.  Lookups in list_thm made while building the
   rewrite table raise Not_found for a missing key. *)
let prove_hypermap_lp hyp_str precision constraints target_var_bounds var_bounds =
  let list_hyp, list_thm, fun_table = compute_all hyp_str in

  (* Maps the name of each list_of_* constant to the theorem stored in
     list_thm under the corresponding short name. *)
  let table_set_rewrites =
    let hyp = Hashtbl.create 10 in
    let () =
      List.iter
        (fun (set, name) -> Hashtbl.add hyp set (Hashtbl.find list_thm name))
        [
          "list_of_darts", "darts";
          "list_of_darts3", "darts3";
          "list_of_darts4", "darts4";
          "list_of_dartsX", "dartsX";
          "list_of_nodes", "nodes";
          "list_of_edges", "edges";
          "list_of_faces", "faces";
          "list_of_faces3", "faces3";
          "list_of_faces4", "faces4";
          "list_of_faces5", "faces5";
          "list_of_faces6", "faces6";
        ] in
      hyp in

  let l_var_list = `L:((num)list)list` in

  (* Rewrites subterms in the inequality: the left operand of the relation
     is a sum whose summands are rewritten one by one. *)
  let rewrite_ineq ineq =
    let rec rewrite_lhs = fun tm ->
      (* Rewrites one summand: either a product "c * var arg" or a list_sum. *)
      let rewrite_one = fun tm ->
        if (is_binop mul_op_real tm) then
          let mul_tm, var_tm = dest_comb tm in
          let var_f, arg = dest_comb var_tm in

          (* Evaluates the argument of a variable.  Applications headed by
             CONS or by the pairing constant are left unchanged; for any
             other constant the sub-argument is converted first and the
             application is then evaluated by a dedicated conversion or by a
             lookup in fun_table. *)
          let rec convert_arg = fun arg ->
            if (is_comb arg) then
              let ltm, sub_arg' = dest_comb arg in
              let const_name = (fst o dest_const) (if (is_const ltm) then ltm else rator ltm) in
                if (const_name = "CONS" || const_name = ",") then
                  REFL arg
                else
                  (try
                     let sub_arg_th = convert_arg sub_arg' in
                     let th0 = AP_TERM ltm sub_arg_th in
                     let rtm = rand(concl th0) in
                     let th1 =
                       if (const_name = "set_of_list") then
                         set_of_list_conv rtm
                       else if (const_name = "FST") then
                         fst_conv rtm
                       else if (const_name = "SND") then
                         snd_conv rtm
                       else
                         (* Precomputed values: table for this constant,
                            keyed by the converted argument term. *)
                         let table = Hashtbl.find fun_table const_name in
                           Hashtbl.find table (rand rtm) in
                       TRANS th0 th1
                   with
                     (* Keep the reason of the failure instead of dropping it *)
                     | Failure msg -> failwith ("convert_arg: "^const_name^": "^msg)
                     | _ -> failwith ("convert_arg: "^const_name))
            else
              REFL arg in

          let arg_th = convert_arg arg in
            AP_TERM mul_tm (AP_TERM var_f arg_th)

        else
          (* tm should be list_sum *)
          list_sum_conv BETA_CONV tm in

        if (is_binop add_op_real tm) then
          let lhs, rhs = dest_binop add_op_real tm in
          let lhs_th = rewrite_one lhs in
          let rhs_th = rewrite_lhs rhs in
          let th1 = MK_COMB(AP_TERM add_op_real lhs_th, rhs_th) in
            (* If the rewritten first summand is itself a sum, move the
               parentheses to the right again. *)
            if (is_binop add_op_real (rand(concl lhs_th))) then
              let th2 = plus_assoc_conv (rand(concl th1)) in
                TRANS th1 th2
            else
              th1
        else
          rewrite_one tm in

    let th0 = BETA_RULE ineq in
    (* The conclusion has the shape "op l_tm rtm"; only l_tm is rewritten. *)
    let ltm, rtm = dest_comb(concl th0) in
    let op_tm, l_tm = dest_comb ltm in
    let lhs_th = rewrite_lhs l_tm in
      EQ_MP (AP_THM (AP_TERM op_tm lhs_th) rtm) th0 in

  (* This function generates all inequalities from a given base inequality *)
  let get_ineqs = fun ineq indices ->
    (* Instantiate L with the concrete list hypermap and discharge the
       hypothesis with the "good_list" theorem. *)
    let t0 = INST[list_hyp, l_var_list] ineq in
    let t1 = MY_PROVE_HYP (Hashtbl.find list_thm "good_list") t0 in
    let all_tm, set_tm = dest_comb (concl t1) in
    (* Replace the set argument using the theorem registered for its head
       constant, then pick the requested instances. *)
    let set_th = Hashtbl.find table_set_rewrites ((fst o dest_const o rator) set_tm) in
    let t2 = EQ_MP (AP_TERM all_tm set_th) t1 in
    let ths = select_all t2 indices in
      map rewrite_ineq ths in


  let precision_constant = Int 10 **/ (Int precision) and
      target_bound = `&12` in

  (* This function generates all inequalities with the given name and indices,
     multiplies these inequalities by given coefficients, and adds up the obtained
     inequalities *)
  let sum_step = fun (name, indices, c) ->
    try
      let ineq = find_ineq precision name in
      let ineqs = get_ineqs ineq indices in
      let s1 = map transform_le_ineq (zip ineqs c) in
        List.fold_left add_step' dummy s1
    with
      (* Still a Failure naming the inequality, but the cause is preserved *)
      | Failure msg -> failwith ("Problem: "^name^": "^msg)
      | _ -> failwith ("Problem: "^name) in

  (* Find all sums *)
  let s1' = List.fold_left add_step' dummy (map sum_step constraints) in
  (* Only the constraint sum is multiplied by 10^precision. *)
  let s1 = mul_step s1' (mk_real_int precision_constant) in
  let s2 = List.fold_left add_step' dummy (map sum_step target_var_bounds) in
  let s3 = List.fold_left add_step' dummy (map sum_step var_bounds) in
  let s4 = add_step' (add_step' s1 s2) s3 in

  (* Final transformations *)
  let r6 = CONV_RULE (DEPTH_CONV NUM_TO_NUMERAL_CONV) s4 in
  (* Multiply by 1 / (10^precision)^3. *)
  let m = term_of_rat (precision_constant */ precision_constant */ precision_constant) in
  let r7 = mul_rat_step r6 (mk_comb (`(/) (&1)`, m)) in
  let r8 = REWRITE_RULE[lin_f; ITLIST; REAL_ADD_RID] r7 in
  (* Check numerically that the right-hand side of r8 is at most the target
     bound and chain the two inequalities. *)
  let r9 = EQT_ELIM (REAL_RAT_LE_CONV (mk_binop le_op_real ((rand o concl) r8) target_bound)) in
    MATCH_MP REAL_LE_TRANS (CONJ r8 r9);;