needs "../formal_lp/hypermap/ineqs/lp_gen_ineqs.hl";;
needs "../formal_lp/hypermap/ineqs/lp_approx_ineqs.hl";;
needs "../formal_lp/hypermap/ineqs/lp_ineqs_defs.hl";;
needs "../formal_lp/hypermap/ineqs/lp_ineqs_proofs-compiled.hl";;
needs "../formal_lp/hypermap/ineqs/lp_ineqs_proofs2-compiled.hl";;
needs "../formal_lp/hypermap/ineqs/lp_main_estimate-compiled.hl";;
needs "../formal_lp/hypermap/arith_link.hl";;

module Lp_ineqs = struct

open Lp_ineqs_proofs;;

(* A raw LP inequality as kept in the registry built by add_raw_ineq. *)
type raw_lp_ineq = {
  (* Identifier: find_raw_ineq looks entries up by it, and process_raw_ineqs
     passes it on as the key of the generated list inequalities. *)
  name : string;
  (* Statement of the inequality; generate_ineq0 hands this term to
     Lp_gen_ineqs.generate_ineq. *)
  tm : term;
  (* Optional theorem for the statement.  generate_ineq0 uses it with
     PROVE_HYP and substitutes TRUTH when it is None. *)
  proof: thm option;
  (* Second lookup key of find_raw_ineq; add_lp_list_ineq also uses it to
     choose which of its two tables receives the generated inequalities. *)
  std_only: bool;
};;


(* A processed list inequality as stored in the tables of add_lp_list_ineq
   and returned by find_ineq. *)
type lp_ineq = {
  (* The stored inequality theorem. *)
  ineq_th : thm;
  (* Set to true by add_lp_list_ineq when the inequality ranges over
     list_dart_pairs, and to false otherwise. *)
  pair_flag : bool;
};;


(* Registry of raw inequalities, held in a private mutable list that is
   shared by the three exported closures (newest entry first).
   - raw_ineq_list ()            : returns the current list;
   - add_raw_ineq name std tm th : prepends a new record, th being the
                                   optional proof theorem;
   - find_raw_ineq std_flag name : returns the first record whose name and
                                   std_only flag both match; a miss fails
                                   in whatever way find fails. *)
let raw_ineq_list, add_raw_ineq, find_raw_ineq =
  let registry = ref [] in
  let all_entries () = !registry in
  let register name std tm th =
    let entry = {name = name; tm = tm; proof = th; std_only = std} in
      registry := entry :: !registry in
  let lookup std_flag name =
    let is_wanted entry = entry.name = name && entry.std_only = std_flag in
      find is_wanted !registry in
    all_entries, register, lookup;;


(* Combine all nonlinear inequalities related to linear programs into one
   definition.  The constant lp_ineqs is defined as the conjunction of the
   duplicate-free statements of every entry of Ineq.ineqs that either carries
   an Lp, Tablelp or Lp_aux tag, or whose id appears in the explicit id list
   below. *)
let lp_ineqs_def =
  (* True exactly for the tags that mark an inequality as LP-related *)
  let is_lp_tag tag =
    match tag with
      | Lp -> true
      | Tablelp -> true
      | Lp_aux _ -> true
      | _ -> false in
  let rec has_lp_tag tags =
    match tags with
      | [] -> false
      | tag :: rest -> if is_lp_tag tag then true else has_lp_tag rest in
  (* Inequalities included by id regardless of their tags *)
  let extra_ids = ["6170936724"] in
  let is_selected ineq = has_lp_tag ineq.tags or mem ineq.idv extra_ids in
  let selected = filter is_selected !Ineq.ineqs in
  let statements = setify (map (fun entry -> entry.ineq) selected) in
  let conj_tm = end_itlist (curry mk_conj) statements in
    new_definition (mk_eq (`lp_ineqs:bool`, conj_tm));;


(* Discharges, in the given theorem, every hypothesis that coincides with a
   conjunct of the lp_ineqs definition.  The conjunct theorems are derived
   once, when this value is defined, each under the single assumption
   lp_ineqs; PROVE_HYP then trades the matching hypotheses for that
   assumption. *)
let add_lp_ineqs_hyp =
  let lp_ineq_ths = CONJUNCTS (UNDISCH (MATCH_MP EQ_IMP lp_ineqs_def)) in
    fun th -> itlist PROVE_HYP lp_ineq_ths th;;


(* Replaces all general hypotheses with the packaged LP assumptions.
   The conjuncts of the definitions lp_fan, lp_cond (instantiated at type
   num) and lp_main_estimate are derived once, each under the hypotheses
   produced by UNDISCH_ALL on the unfolded definition.  Every hypothesis of
   the input theorem matching one of these conjuncts is then discharged with
   PROVE_HYP, and finally add_lp_ineqs_hyp does the same for lp_ineqs.

   lp_ineq_flag selects the starting theorem: the specialised and
   undischarged form of th is used when the flag is false, or when the
   conclusion of that form is a universal quantification or an application
   of the binary operator ALL; otherwise th itself is used unchanged. *)
let add_lp_hyp =
  (* Unfolds a definitional equivalence into the list of its conjuncts *)
  let conjuncts_of_def def_th =
    CONJUNCTS (UNDISCH_ALL (MATCH_MP EQ_IMP (SPEC_ALL def_th))) in
  let lp_fan_ths = conjuncts_of_def Lp_ineqs_proofs.lp_fan in
  let lp_cond_ths =
    conjuncts_of_def (INST_TYPE[`:num`, aty] Lp_ineqs_proofs.lp_cond) in
  let lp_main_estimate_ths =
    conjuncts_of_def Lp_main_estimate.lp_main_estimate in
    fun lp_ineq_flag th ->
      let stripped_th = (UNDISCH_ALL o SPEC_ALL) th in
      let stripped_concl = concl stripped_th in
      let use_stripped =
        (not lp_ineq_flag) or is_forall stripped_concl or is_binary "ALL" stripped_concl in
      let start_th = if use_stripped then stripped_th else th in
      let without_fan = itlist PROVE_HYP lp_fan_ths start_th in
      let without_cond = itlist PROVE_HYP lp_cond_ths without_fan in
      let without_estimate = itlist PROVE_HYP lp_main_estimate_ths without_cond in
        add_lp_ineqs_hyp without_estimate;;


(* Parses the string s into a term.  type_hints is a list of pairs
   (variable name, type); each type is converted to a pretype and handed to
   retypecheck as a typing hint.  Fails if the parser leaves any input
   unconsumed after the term. *)
let parse_lp_ineq s type_hints =
  let to_pretype_hint (var, ty) = var, pretype_of_type ty in
  let hints = map to_pretype_hint type_hints in
    match (parse_preterm o lex o explode) s with
      | ptm, [] -> (term_of_preterm o retypecheck hints) ptm
      | _, _ :: _ -> failwith "Unparsed input following term";;


(* Registers a raw inequality given as a string.
   The argument is a tuple (name, std_flag, s_tm, th): s_tm is parsed with
   fixed type hints for H, V and node_mod, and th is an optional theorem.
   When th is present it is passed through add_lp_hyp and its conjoined
   antecedents are split into nested implications before being stored. *)
let add_lp_ineq_str (name, std_flag, s_tm, th) =
  let type_hints = ["H", `:(A)hypermap`; "V", `:C->bool`; "node_mod", `:A->C`] in
  let tm = parse_lp_ineq s_tm type_hints in
  let prepare ineq_th = PURE_REWRITE_RULE[GSYM IMP_IMP] (add_lp_hyp true ineq_th) in
  let proof =
    match th with
      | Some ineq_th -> Some (prepare ineq_th)
      | None -> None in
    add_raw_ineq name std_flag tm proof;;


(* Turns a theorem about a concrete fan (V,E) into a pair (statement, proof).
   The proof is the theorem after add_lp_hyp, with conjoined antecedents
   split into nested implications.  The statement is the conclusion of that
   proof in which every fan-specific term of the table below (second
   component of a pair) has been replaced by its generic counterpart (first
   component), as subst does for a list of such pairs. *)
let mk_lp_ineq =
  let fan_to_generic = [
    `dart (H:(real^3#real^3)hypermap)`, `dart_of_fan (V:real^3->bool,E)`;
    `H:(real^3#real^3)hypermap`, `hypermap_of_fan (V:real^3->bool,E)`;
    `face_map (H:(real^3#real^3)hypermap)`, `f_fan_pair_ext (V:real^3->bool,E)`;
    `azim_mod:real^3#real^3->real`, `azim_dart (V:real^3->bool,E)`;
    `azim2_mod:real^3#real^3->real`, `azim2_fan (V:real^3->bool,E)`;
    `azim3_mod:real^3#real^3->real`, `azim3_fan (V:real^3->bool,E)`;
    `rhazim_mod:real^3#real^3->real`, `rhazim_fan (V:real^3->bool,E)`;
    `rhazim2_mod:real^3#real^3->real`, `rhazim2_fan (V:real^3->bool,E)`;
    `rhazim3_mod:real^3#real^3->real`, `rhazim3_fan (V:real^3->bool,E)`;
    `sol_mod:(real^3#real^3->bool)->real`, `sol_fan (V:real^3->bool,E)`;
    `tau_mod:(real^3#real^3->bool)->real`, `tau_fan (V:real^3->bool,E)`;
    `y1_mod:real^3#real^3->real`, `y1_fan`;
    `y2_mod:real^3#real^3->real`, `y2_fan`;
    `y3_mod:real^3#real^3->real`, `y3_fan (V,E)`;
    `y4_mod:real^3#real^3->real`, `y4_fan (V,E)`;
    `y5_mod:real^3#real^3->real`, `y5_fan (V,E)`;
    `y6_mod:real^3#real^3->real`, `y6_fan`;
    `y7_mod:real^3#real^3->real`, `y7_fan (V,E)`;
    `y8_mod:real^3#real^3->real`, `y8_fan (V,E)`;
    `y9_mod:real^3#real^3->real`, `y9_fan (V,E)`;
    `y4'_mod:real^3#real^3->real`, `y4'_fan (V,E)`;
    `ye_mod:real^3#real^3->real`, `ye_fan`;
    `yn_mod:real^3->real`, `yn_fan`;
  ] in
    fun ineq_th ->
      let with_lp_hyps = add_lp_hyp true ineq_th in
      let proof = PURE_REWRITE_RULE[GSYM IMP_IMP] with_lp_hyps in
      let generic_tm = subst fan_to_generic (concl proof) in
        generic_tm, proof;;


(* Registers a raw inequality given as a theorem.  The argument is a tuple
   (name, std_flag, ineq_th); the stored statement and proof are the two
   components returned by mk_lp_ineq for ineq_th. *)
let add_lp_ineq_th (name, std_flag, ineq_th) =
  let generic_tm, proof = mk_lp_ineq ineq_th in
  let stored_proof = Some proof in
    add_raw_ineq name std_flag generic_tm stored_proof;;

(* Builds, without registering it, a raw inequality record from a string.
   The string is parsed with fixed type hints for H, V and node_mod; the
   record carries no proof. *)
let gen_raw_ineq_data_str name std_flag str =
  let type_hints = ["H", `:(A)hypermap`; "V", `:C->bool`; "node_mod", `:A->C`] in
  let parsed_tm = parse_lp_ineq str type_hints in
    {name = name; tm = parsed_tm; proof = None; std_only = std_flag};;

(* Builds, without registering it, a raw inequality record from a theorem.
   Statement and proof are the two components returned by mk_lp_ineq. *)
let gen_raw_ineq_data name std_flag ineq_th =
  let generic_tm, proof = mk_lp_ineq ineq_th in
  let stored_proof = Some proof in
    {name = name; tm = generic_tm; proof = stored_proof; std_only = std_flag};;

(******************************)


(* Generates a general form of an inequality.
   The statement of the raw inequality is passed to
   Lp_gen_ineqs.generate_ineq and the result to add_lp_hyp.  The recorded
   proof is then used with PROVE_HYP to discharge a matching hypothesis;
   when the record carries no proof, TRUTH is used in its place. *)
let generate_ineq0 ineq =
  let general_th = add_lp_hyp true (Lp_gen_ineqs.generate_ineq ineq.tm) in
  let proof_th =
    match ineq.proof with
      | None -> TRUTH
      | Some th -> th in
    PROVE_HYP proof_th general_th;;


(* generate_ineq1 simplify_all_flag ineq
   Produces the list form of a raw inequality.
   The theorem from generate_ineq0 is opened up by strip_ALL, the list
   definitions collected in def_ths (followed by the POWER identities) are
   rewritten away, and the result is closed up again with ALL_MEM.
   - simplify_all_flag = true  : the whole conclusion is rewritten;
   - simplify_all_flag = false : only the antecedents of the chain of
     implications are rewritten, and the final consequent is left as it is.
   If the membership hypothesis is one of the four terms tabulated in
   mem_ths, the corresponding theorem is first used with PROVE_HYP.
   Everything before the final closure is computed once, when this value is
   defined. *)
let generate_ineq1 =
  let power_th =
    prove(`!f (x:A). (f POWER 2) x = f (f x) /\ (f POWER 3) x = f (f (f x))`,
          REWRITE_TAC[Hypermap.POWER_2; Hypermap.THREE; Hypermap.POWER; o_THM]) in
  (* List form of a definitional theorem, via a temporary raw inequality *)
  let mk_list_th th =
    let tm, proof = mk_lp_ineq th in
      generate_ineq0 {name = "tmp"; std_only = false; tm = tm; proof = Some proof} in
  (* Rewrites the theorem with the reversed ALL_MEM, specialises it to the
     bound variable of the predicate and moves the resulting antecedent into
     the hypotheses.  Returns the new theorem together with that variable
     and the antecedent term. *)
  let strip_ALL th =
    let ltm, _ = dest_comb (concl th) in
    let p_tm = rand ltm in
    let var_tm, _ = dest_abs p_tm in
    let th1 = (BETA_RULE o SPEC var_tm o PURE_REWRITE_RULE[GSYM ALL_MEM]) th in
    let mem_tm = (fst o dest_imp o concl) th1 in
      UNDISCH th1, (var_tm, mem_tm) in
  let th_rule =
    (UNDISCH_ALL o INST_TYPE [`:num`, aty; `:num`, `:D`; `:real^3`, `:C`] o
       SPEC_ALL o REWRITE_RULE[GSYM IMP_IMP; RIGHT_IMP_FORALL_THM]) in
  let def_ths =
    let defs = map (fst o strip_ALL o mk_list_th)
      [y1_def; y2_def; y3_def; y4_def; y5_def; y6_def; y9_def] in
    let r = (fst o strip_ALL o add_lp_hyp true o REWRITE_RULE[GSYM IMP_IMP]) in
    let extra_defs = map r [
      y8_list_def;
      y4'_list_def;
      y4'_list_f_def;
      y4'_list_inv_f_def;
      y4'_list_ff_def;
      y4'_list_fff_def;
    ] in
      extra_defs @ defs in
  let mem_ths = [
    `MEM (d:num#num) (list_of_darts3 L)`, th_rule Lp_gen_theory.list_of_darts3_subset;
    `MEM (d:num#num) (list_of_darts4 L)`, th_rule Lp_gen_theory.list_of_darts4_subset;
    `MEM (d:num#num) (list_of_darts5 L)`, th_rule Lp_gen_theory.list_of_darts5_subset;
    `MEM (d:num#num) (list_of_darts6 L)`, th_rule Lp_gen_theory.list_of_darts6_subset;
  ] in
  (* Built once here instead of once per antecedent and per call *)
  let unfold_conv = PURE_REWRITE_CONV def_ths THENC PURE_REWRITE_CONV[power_th] in
  (* Equation rewriting every antecedent of a chain of implications while
     leaving the final consequent unchanged *)
  let rec simplify tm =
    if is_imp tm then
      let imp_tm = rator (rator tm) in
      let p_tm, q_tm = dest_imp tm in
      let p_th = unfold_conv p_tm in
      let q_th = simplify q_tm in
        MK_COMB (AP_TERM imp_tm p_th, q_th)
    else
      REFL tm in
    fun simplify_all_flag ineq ->
      let th1 = generate_ineq0 ineq in
      let th2, (var_tm, mem_tm) = strip_ALL th1 in
      let eq_th =
        if simplify_all_flag then unfold_conv (concl th2)
        else simplify (concl th2) in
      let th3 = EQ_MP eq_th th2 in
      (* Best effort: membership terms outside the table are left alone *)
      let th4 =
        try PROVE_HYP (assoc mem_tm mem_ths) th3
        with Failure _ -> th3 in
      let th5 = (GEN var_tm o DISCH mem_tm) th4 in
        PURE_REWRITE_RULE[ALL_MEM] th5;;


(* Default generator: only the antecedents are simplified *)
let generate_ineq = generate_ineq1 false;;


(* add_lp_list_ineq (name, std_only, list_ineq_th)
     Derives approximate versions of the list inequality with
     Lp_approx_ineqs.generate_ineqs for the list 3..7 and stores them under
     name; the second list returned by the generator is stored under name
     with the suffix _neg.  Entry i of each result list is stored in table
     slot i, for i from 3 to 7.  The destination is the standard table when
     std_only is true and the general table otherwise.
   find_ineq std_flag precision name
     Looks name up in slot precision.  With std_flag the standard table is
     consulted first and the general table is the fallback; otherwise only
     the general table is used.  Raises Not_found when no entry exists.
   NOTE(review): 3..7 is presumably the number of decimal digits of the
   approximation -- confirm against Lp_approx_ineqs. *)
let add_lp_list_ineq, find_ineq =
  (* Eight slots so that the indices 3..7 can be used directly *)
  let ineq_table = Array.init 8 (fun _ -> Hashtbl.create 10) in
  let std_ineq_table = Array.init 8 (fun _ -> Hashtbl.create 10) in
  let add std_flag i pair_flag name ineq_th =
    let table = if std_flag then std_ineq_table.(i) else ineq_table.(i) in
      Hashtbl.add table name {ineq_th = ineq_th; pair_flag = pair_flag} in
  let decimal_int_th =
    prove(`!n. DECIMAL n 1 = &n`,
          REWRITE_TAC[DECIMAL; REAL_DIV_1]) in
  let pos_thms =
    let inst = INST[`V:real^3->bool,E:(real^3->bool)->bool`,
                    `fan:(real^3->bool)#((real^3->bool)->bool)`] in
      map inst Lp_ineqs_def.list_var_pos in
  (* Converts NUMERAL subterms with Arith_nat.NUMERAL_TO_NUM_CONV.
     Applications of DECIMAL are never entered; applications of real_of_num
     are entered only when real_flag is true. *)
  let rec rewrite_num_conv real_flag tm =
    match tm with
      | Var _ -> REFL tm
      | Const _ -> REFL tm
      | Abs (v_tm, b_tm) -> ABS v_tm (rewrite_num_conv real_flag b_tm)
      | Comb (Const ("NUMERAL", _), _) -> Arith_nat.NUMERAL_TO_NUM_CONV tm
      | Comb (Const ("real_of_num", _) as ltm, rtm) ->
          if real_flag then AP_TERM ltm (rewrite_num_conv real_flag rtm)
          else REFL tm
      | Comb (Comb (Const ("DECIMAL", _), _), _) -> REFL tm
      | Comb (ltm, rtm) ->
          MK_COMB (rewrite_num_conv real_flag ltm, rewrite_num_conv real_flag rtm) in
  (* Normalises the numerals of a list inequality.  When the body of the
     predicate is an implication, the consequent is converted with
     real_flag = true and the antecedent with real_flag = false. *)
  let rewrite_num ineq_th =
    let th1 = PURE_REWRITE_RULE[decimal_int_th] ineq_th in
    let imp_flag = (is_imp o snd o dest_abs o lhand o concl) th1 in
      if imp_flag then
        let th2 = PURE_REWRITE_RULE[IMP_IMP] th1 in
        let th3 = CONV_RULE ((LAND_CONV o ABS_CONV o RAND_CONV) (rewrite_num_conv true)) th2 in
        let th4 = CONV_RULE ((LAND_CONV o ABS_CONV o LAND_CONV) (rewrite_num_conv false)) th3 in
          PURE_REWRITE_RULE[GSYM IMP_IMP] th4
      else
        CONV_RULE (rewrite_num_conv true) th1 in
    (* add_lp_list_ineq *)
    (fun (name, std_only, list_ineq_th) ->
       let approx_ths, neg_approx_ths =
         Lp_approx_ineqs.generate_ineqs pos_thms [3;4;5;6;7] list_ineq_th in
       let set_name = (fst o dest_const o rator o rand o concl o hd) approx_ths in
       let pair_flag, approx_ths, neg_approx_ths =
         if set_name = "list_dart_pairs" then
           let r = PURE_REWRITE_RULE[Lp_gen_theory.ALL_list_dart_pairs_split; BETA_THM; FST; SND] in
             true, map r approx_ths, map r neg_approx_ths
         else
           false, approx_ths, neg_approx_ths in
       let approx_ths = zip (3--7) approx_ths in
       (* The second list may be empty, and zip needs equal lengths *)
       let neg_approx_ths =
         match neg_approx_ths with
           | [] -> []
           | _ :: _ -> zip (3--7) neg_approx_ths in
       let store key (i, t) = add std_only i pair_flag key (rewrite_num t) in
         List.iter (store name) approx_ths;
         List.iter (store (name^"_neg")) neg_approx_ths),
    (* find_ineq *)
    (fun std_flag precision name ->
       if std_flag then
         (try Hashtbl.find std_ineq_table.(precision) name
          with Not_found -> Hashtbl.find ineq_table.(precision) name)
       else
         Hashtbl.find ineq_table.(precision) name);;

(* Generates all list inequalities from raw inequalities.
   Every record currently returned by raw_ineq_list is passed through
   generate_ineq and the result is stored with add_lp_list_ineq under the
   name and std_only flag of the record.  A progress line is printed before
   each inequality is processed.  The counters are reset on every call. *)
let process_raw_ineqs =
  let report s =
    Format.print_string s;
    Format.print_newline();
    Format.print_flush() in
  let counter = ref 0 and total = ref 0 in
  let process_ineq ineq =
    counter := !counter + 1;
    report (sprintf "Processing: %s (%d / %d)" ineq.name !counter !total);
    let th = generate_ineq ineq in
      add_lp_list_ineq (ineq.name, ineq.std_only, th) in
    fun () ->
      (* Read the registry once, so the total and the processed list agree *)
      let ineqs = raw_ineq_list() in
        counter := 0;
        total := length ineqs;
        List.iter process_ineq ineqs;;

end;;