needs "../formal_lp/hypermap/main/lp_certificate.hl";;
needs "../formal_lp/glpk/lp_binary_certificate.hl";;
needs "../formal_lp/hypermap/computations/informal_computations.hl";;
needs "../glpk/tame_archive/lpproc.ml";;
needs "../glpk/tame_archive/hard_lp.ml";;


module Build_certificates = struct


(* Temporary directory where certificates are computed. *)
(* build_terminal_case cleans this directory (see clean_out_dir) before each terminal case *)
let tmp_dir = ref (flyspeck_dir ^ "/../formal_lp/glpk/tmp");;
(* Output directory for certificates (the .dat files written by build_and_save_all) *)
let output_dir = ref (flyspeck_dir ^ "/../formal_lp/glpk/binary");;
(* Directory containing LP-HL.exe (executed with mono by build_terminal_case) *)
let lp_hl_dir = ref (flyspeck_dir ^ "/../formal_lp/LP-HL/LP-HL/bin/Release");;
(* Main model file: written by make_models and passed to glpsol by build_terminal_case *)
let model2_path = ref (tame_dir ^ "/model2.mod");;



open Glpk_link;;
open Lp_informal_computations;;
open Lpproc;;
open Hard_lp;;
open List;;
open Lp_certificate;;
open Lp_binary_certificate;;


(* Prints one line of heap statistics taken from Gc.stat: the allocated, free and *)
(* total heap sizes (words converted to bytes and divided by 1024) followed by the *)
(* free/total ratio *)
let mem_stat () =
  let gc = Gc.stat () in
  let bytes_per_word = float_of_int (Sys.word_size / 8) in
  let size_of_words n = float_of_int n *. bytes_per_word /. 1024.0 in
  let free_size = size_of_words gc.Gc.free_words in
  let heap_size = size_of_words gc.Gc.heap_words in
    print_string
      (sprintf "allocated = %f (free = %f; total_size = %f; %f)\n"
         (heap_size -. free_size) free_size heap_size (free_size /. heap_size));;


(* Runs the shell command [com] with [dir] as the working directory and returns *)
(* the exit code given by Sys.command. *)
(* The previous working directory is restored afterwards, also when Sys.command *)
(* raises an exception (the exception is re-raised after the directory is restored). *)
(* Sys.chdir raises Sys_error when [dir] cannot be entered; in that case the *)
(* working directory has not been changed. *)
let run_command dir com =
  let cur_dir = Sys.getcwd() in
  let _ = Sys.chdir dir in
  let result =
    try Sys.command com
    with e -> (Sys.chdir cur_dir; raise e) in
  let _ = Sys.chdir cur_dir in
    result;;


(* Calls make_model and then derives the model file !model2_path from the model file *)
(* with sed (all commands are run in tame_dir, their exit codes are ignored): *)
(*  - if include_main is false, the main: line is first removed and the result is *)
(*    copied back over the model file; *)
(*  - the objective is replaced by the sum of ln[i] over all nodes and the *)
(*    lnsum_def: line is removed *)
let make_models include_main =
  let _ = make_model() in
  let objective_rule = "s/maximize objective:.*/maximize objective: sum{i in node} ln[i];/" in
  let lnsum_rule = "s/lnsum_def:.*//" in
  let main_rule = "s/main:.*//" in
  let shell cmd = ignore (run_command tame_dir cmd) in
  let () =
    if include_main then () else
      let strip_main = sprintf "sed -e '%s' %s > %s" main_rule model !model2_path in
      let copy_back = sprintf "cp %s %s" !model2_path model in
      let () = shell strip_main in
        shell copy_back in
  (* Every rule becomes one -e option followed by a space *)
  let options =
    String.concat "" (map (fun rule -> sprintf "-e '%s' " rule) [objective_rule; lnsum_rule]) in
    shell (sprintf "sed %s %s > %s" options model !model2_path);;


(* Writes the output of ampl_of_bb for [bb] into the file [fname] *)
(* (the file is created or truncated by open_out). *)
(* The channel is closed also when ampl_of_bb raises an exception; *)
(* the exception is re-raised after the channel is closed. *)
let ampl_of_bb' fname bb =
  let out = open_out fname in
  let _ =
    try ampl_of_bb out bb
    with e -> (close_out_noerr out; raise e) in
    close_out out;;


(* Writes the string [str] into the file [fname] (created or truncated by open_out). *)
(* The channel is closed also when writing fails; the exception is re-raised *)
(* after the channel is closed. *)
let save_string fname str =
  let out = open_out fname in
  let _ =
    try output_string out str
    with e -> (close_out_noerr out; raise e) in
    close_out out;;


(* Returns the index which Glpk_link.enumerate assigns to the first element of [list] *)
(* satisfying [p]. The exception raised by find is propagated if there is no such element *)
let find_index p list =
  let indexed = Glpk_link.enumerate list in
  let index, _ = find (fun (_, x) -> p x) indexed in
    index;;

(* Returns the position (starting from 0) of the first element of [fs] which is equal *)
(* to one of the faces returned by Glpk_link.rotation [f]. *)
(* Fails with the message find_face_index if there is no such element *)
let find_face_index fs f =
  let rots = Glpk_link.rotation [f] in
  let is_rotation face = exists (fun r -> face = r) rots in
  let rec search index faces =
    match faces with
      | [] -> failwith "find_face_index"
      | face :: rest ->
          if is_rotation face then index else search (index + 1) rest in
    search 0 fs;;


(* For each face of [fs0], in order, the index in [fs] found by find_face_index *)
let build_permutation fs0 fs =
  let index_in_fs = find_face_index fs in
    map index_in_fs fs0;;


(* Saves the description of one case into the directory [out_dir]: *)
(*  - <name>_pars.txt contains the parameters of [bb] written by ampl_of_bb'; *)
(*  - flyspeck-<name>.txt contains the name, the infeasibility flag, the faces of *)
(*    hyp_list and the permutation computed by build_permutation from hyp_list *)
(*    and faces bb *)
let save_info out_dir name infeasible hyp_list bb =
  let perm = build_permutation hyp_list (faces bb) in
  let _ = ampl_of_bb' (sprintf "%s/%s_pars.txt" out_dir name) bb in
  let string_of_face face = unsplit "," string_of_int face in
  let info = join_lines [
    sprintf "name: %s" name;
    sprintf "infeasible: %b" infeasible;
    sprintf "hypermap: %s" (unsplit ";" string_of_face hyp_list);
    sprintf "faces: %s" (unsplit ", " string_of_int perm);
  ] in
    save_string (out_dir ^ "/flyspeck-" ^ name ^ ".txt") info;;


(* Removes all entries of the directory [out_dir] except the permanent files *)
(* 000.txt and string_archive.txt *)
let clean_out_dir =
  let is_permanent name = mem name ["000.txt"; "string_archive.txt"] in
    fun out_dir ->
      let remove_entry name =
        if is_permanent name then ()
        else Sys.remove (sprintf "%s/%s" out_dir name) in
        Array.iter remove_entry (Sys.readdir out_dir);;

(************************************)	

(* Builds a terminal case from the given hypermap *)
(* Arguments of the resulting function: *)
(*   hyp_list   - the hypermap as a list of integer lists (saved by save_info) *)
(*   infeasible - if true then the additional problem with slack variables is *)
(*                built by the sed scripts and solved *)
(*   bb         - the case record (its hypermap_id is used in error messages) *)
(* All files are created in !tmp_dir which is cleaned with clean_out_dir first. *)
(* Fails if glpsol or LP-HL.exe returns a non-zero exit code. *)
(* Returns the result of read_binary_terminal for the file out_out.bin in !tmp_dir *)
let build_terminal_case =
  (* Applies the sed scripts to <out_name>.lp and runs glpsol on <out_name>_slack2.lp. *)
  (* Returns the exit code of glpsol; the exit codes of the sed commands are ignored *)
  let create_infeasible_solution out_name =
    let sed1_cmd = sprintf "sed -f %s/../sed_add_slack.sed %s.lp > %s_slack.lp" !tmp_dir out_name out_name in
    let sed2_cmd = sprintf "sed -f %s/../sed_build_objective.sed %s.lp > /dev/null" !tmp_dir out_name in
    let sed3_cmd = sprintf "sed -f %s/../sed_replace_objective.sed %s_slack.lp > %s_slack2.lp" !tmp_dir out_name out_name in
    (* The commands must be executed in this order: sed3 reads the file *)
    (* <out_name>_slack.lp written by sed1. They are sequenced with nested lets *)
    (* because the evaluation order of a let ... and ... binding is unspecified *)
    let _ = run_command !tmp_dir sed1_cmd in
    let _ = run_command !tmp_dir sed2_cmd in
    let _ = run_command !tmp_dir sed3_cmd in
    let solve_com2 = sprintf "glpsol --cpxlp %s_slack2.lp -w %s.txt" out_name out_name in
      run_command !tmp_dir solve_com2 in

    fun hyp_list infeasible bb ->
      let error_glpsol = sprintf "build_terminal_case: glpsol failed for %s" bb.hypermap_id in
      let _ = clean_out_dir !tmp_dir in
      let out_name = "out" in
        (* Create info and parameter files *)
      let _ = save_info !tmp_dir out_name infeasible hyp_list bb in
      let data_file = sprintf "%s_pars.txt" out_name in
        (* Create .lp file and solve a feasible problem *)
      let solve_com = sprintf "glpsol -m %s -d %s -w %s.txt --wcpxlp %s.lp > /dev/null"  !model2_path data_file out_name out_name in
      let result = run_command !tmp_dir solve_com in
      let _ = result = 0 or failwith error_glpsol in
        (* Solve an infeasible problem *)
      let result = if infeasible then create_infeasible_solution out_name else 0 in
      let _ = result = 0 or failwith error_glpsol in
        (* Run LP-HL.exe *)
      let com = sprintf "mono %s/LP-HL.exe %s" !lp_hl_dir (sprintf "flyspeck-%s.txt" out_name) in
      let result = run_command !tmp_dir com in
      let _ = result = 0 or failwith (sprintf "LP-HL.exe failed for %s" bb.hypermap_id) in
        (* Create a certificate *)
      let terminal = read_binary_terminal (sprintf "%s/%s_out.bin" !tmp_dir out_name) in
        terminal;;
  

(*************************************)

(* If true, then terminal cases are not constructed: *)
(* build uses empty_terminal instead of calling build_terminal_case *)
let test_mode = ref false;;
(* If true, then report_build_progress reports the numbers of cases and terminals *)
let print_progress = ref false;;


(* Counters of cases and terminal cases shared by four closures: *)
(*   reset_build_counters  - sets both counters to 0 *)
(*   next_build_case       - increments the number of cases *)
(*   next_terminal_case    - increments the number of terminals *)
(*   report_build_progress - reports both counters if !print_progress is true *)
let reset_build_counters, next_build_case, next_terminal_case, report_build_progress =
  let cases = ref 0 in
  let terminals = ref 0 in
  let reset () = (cases := 0; terminals := 0) in
  let count_case () = incr cases in
  let count_terminal () = incr terminals in
  let show_progress () =
    if !print_progress then
      report (sprintf "cases = %d; terminals = %d" !cases !terminals)
    else
      () in
    reset, count_case, count_terminal, show_progress;;
    

(* Adds the following fields to [bb] with modify_bb and describes the new big *)
(* triangles as a list of add_big split infos: *)
(*   e_200_225 - edges of small triangles and opposites of the 200_225 edges *)
(*               which are not yet in d_edge_200_225 *)
(*   e_225_252 - opposites of the 225_252 edges which are not yet in d_edge_225_252 *)
(*   bt        - standard triangles with a 225_252 edge which are not yet in std3_big *)
(* Returns (bb, []) if there is nothing to add. *)
(* Fails with set_face_numerics if a small triangle has a 225_252 edge *)
let set_face_numerics_info bb =
  let dart_face d = face_of_dart d bb in
  let with_opposites darts = nub (darts @ map (fun d -> opposite_edge d bb) darts) in
  let small_edges = with_opposites (rotation bb.std3_small) in
  let short_edges = with_opposites bb.d_edge_200_225 in
  let long_edges = with_opposites bb.d_edge_225_252 in
  let _ = (intersect small_edges long_edges = []) or failwith "set_face_numerics" in
  let short_fields =
    map (fun d -> ("e_200_225", d))
      (subtract (small_edges @ short_edges) bb.d_edge_200_225) in
  let long_fields =
    map (fun d -> ("e_225_252", d)) (subtract long_edges bb.d_edge_225_252) in
  let big_candidates =
    filter (fun fc -> mem fc (std_faces bb) & (length fc = 3))
      (nub (map dart_face long_edges)) in
  let _ = (intersect (rotation bb.std3_small) big_candidates = []) or failwith "set_face_numerics" in
  let big_fields = map (fun fc -> ("bt", fc)) (subtract big_candidates bb.std3_big) in
    match short_fields @ long_fields @ big_fields with
      | [] -> bb, []
      | fields ->
          let new_bb = modify_bb bb false fields [] in
            (* Don't need to add symmetric inequalities and short edges for small triangles: *)
            (* it is done at other steps. Add new big triangles only. *)
          let new_big_faces = subtract big_candidates bb.std3_big in
          let face_darts = zip (map dart_face long_edges) long_edges in
          let info_of_face fc =
            {split_type = "add_big"; split_face = assoc fc face_darts} in
            new_bb, map info_of_face new_big_faces;;


(* For a case with 13 nodes which has a node in node_236_252, or at least two *)
(* nodes among node_218_236 and highish bb: adds the field 200_218 for all nodes *)
(* not yet in any of the node lists and the field 218_236 for the nodes of *)
(* highish bb (with modify_bb). *)
(* Returns the new case and one split info (high or mid), or (bb, []) when *)
(* nothing is changed. Fails with set_node_numerics if the weighted node count *)
(* exceeds 52 *)
let set_node_numerics_info bb =
  if card_node bb <> 13 then bb, [] else
  let n_high = length bb.node_236_252 in
  let n_mid = length bb.node_218_236 in
  let n_highish = length (highish bb) in
  if (n_high = 0) & (n_mid + n_highish < 2) then bb, [] else
  let _ = (18 * (n_mid + n_highish) + 36 * n_high <= 52) or failwith "set_node_numerics" in
  let classified =
    unions [bb.node_200_218; bb.node_218_236; bb.node_236_252; bb.node_218_252] in
  let low_fields = map (fun v -> ("200_218", v)) (subtract (node_list bb) classified) in
  let mid_fields = map (fun v -> ("218_236", v)) (highish bb) in
    match low_fields @ mid_fields with
      | [] -> bb, []
      | vfields ->
          let new_bb = modify_bb bb false [] vfields in
          let split_type, nodes =
            if n_high > 0 then
              "high", bb.node_236_252
            else
              "mid", bb.node_218_236 @ highish bb in
            new_bb, [{split_type = split_type; split_face = nodes}];;


(* Builds the certificate tree of the case bb0 with the hypermap list hyp_list. *)
(* If hard_flag is true, the case is first modified by set_face_numerics_info and *)
(* set_node_numerics_info, and every info returned by them wraps the resulting *)
(* certificate into an Lp_split with one subcase. *)
(* The case is solved with solve_f: a case which is not feasible (is_feas) becomes *)
(* a terminal, otherwise the case is split according to the size of the first *)
(* face of std_faces_not_super. Fails for a face size other than 3, 4, 5, 6 *)
let rec build hard_flag (hyp_list, bb0) =
  let () = next_build_case() in
  let () = report_build_progress() in

  let bb, info_list =
    if hard_flag then
      let bb1, face_infos = set_face_numerics_info bb0 in
      let bb2, node_infos = set_node_numerics_info bb1 in
        bb2, face_infos @ node_infos
    else
      bb0, [] in

  let f_hint = if hard_flag then add_hints_include_flat else (fun t -> ()) in
  let result = solve_f f_hint bb in
  let certificate =
    if is_feas result then
      (* split case *)
      begin
        match length (hd bb.std_faces_not_super) with
          | 3 ->
              if hard_flag then
                split3_hard hyp_list bb
              else
                split3 hard_flag hyp_list bb
          | 4 -> split4 hard_flag hyp_list bb
          | 5 -> split5 hard_flag hyp_list bb
          | 6 -> split6 hard_flag hyp_list bb
          | n -> failwith (sprintf "build: incorrect face size - %d" n)
      end
    else
      (* terminal case *)
      begin
        let () = next_terminal_case() in
        let infeasible =
          match result.lpvalue with
            | Lp_infeasible -> true
            | _ -> false in
          Lp_terminal
            (if !test_mode then
               empty_terminal
             else
               build_terminal_case hyp_list infeasible bb)
      end in
    fold_right (fun info inner -> Lp_split (info, [inner])) info_list certificate

(* split3 - hard cases *)
(* Splits the case according to its first hint (a triangle, an edge or a node), *)
(* clears the hints of the two new cases with clear_hint and builds them in the *)
(* hard mode with the unchanged hypermap list *)
and split3_hard hyp_list bb =
  let () = assert (bb.hints <> []) in
  let bbs, split_type, list =
    match hd (bb.hints) with
      | Triangle_split d -> switch_std3 d bb, "tri", face_of_dart d bb
      | Edge_split d -> switch_edge d bb, "edge", d
      | High_low i -> switch_node bb i, (if mem i (highish bb) then "236" else "218"), [i] in
  let () = ignore (map clear_hint bbs) in
  let subcases = zip [hyp_list; hyp_list] bbs in
    Lp_split ({split_type = split_type; split_face = list},
              map (build true) subcases)

(* split3 *)
(* Splits on the first triangle of std_tri_prebranch: the two cases returned by *)
(* switch3 are built with the unchanged hypermap list *)
and split3 hard_flag hyp_list bb =
  let split_face = hd (std_tri_prebranch bb) in
  let () = assert (length split_face = 3) in
  let subcases = zip [hyp_list; hyp_list] (switch3 bb) in
    Lp_split ({split_type = "tri"; split_face = split_face},
              map (build hard_flag) subcases)
    
(* split4 *)
(* Splits on the first face of std_faces_not_super which must be a quadrilateral. *)
(* The five cases returned by switch4 are paired with the hypermap lists produced *)
(* by split_list for the two darts of the face (13, 24, 13, 24) and with the *)
(* unchanged hypermap list for the last case *)
and split4 hard_flag hyp_list bb =
  let split_face = hd bb.std_faces_not_super in
  let () = assert (length split_face = 4) in
  let bbs = switch4 bb in
  let vertex i = nth split_face i in
  let split13 = split_list hyp_list (vertex 1, vertex 2) in
  let split24 = split_list hyp_list (vertex 0, vertex 1) in
  let subcases = zip [split13; split24; split13; split24; hyp_list] bbs in
    Lp_split ({split_type = "quad"; split_face = split_face},
              map (build hard_flag) subcases)

(* split5 *)
(* Splits on the first face of std_faces_not_super which must be a pentagon. *)
(* The cases returned by switch5 are paired with: the unchanged hypermap list, *)
(* the lists split once along each dart of the face, and the lists split a *)
(* second time along the darts rotated by 2 *)
and split5 hard_flag hyp_list bb =
  let split_face = hd bb.std_faces_not_super in
  let () = assert (length split_face = 5) in
  let bbs = switch5 bb in
  let darts = rotateL 1 (list_pairs split_face) in
  let splits_one = map (fun dart -> split_list hyp_list dart) darts in
  let splits_two = map2 split_list splits_one (rotateL 2 darts) in
  let subcases = zip (hyp_list :: (splits_one @ splits_two)) bbs in
    Lp_split ({split_type = "pent"; split_face = split_face},
              map (build hard_flag) subcases)

(* split6 *)
(* Splits on the first face of std_faces_not_super which must be a hexagon. *)
(* The cases returned by switch6 are paired with the unchanged hypermap list and *)
(* with the lists split along each dart of the face *)
and split6 hard_flag hyp_list bb =
  let split_face = hd bb.std_faces_not_super in
  let () = assert (length split_face = 6) in
  let bbs = switch6 bb in
  let darts = Glpk_link.rotateL 1 (list_pairs split_face) in
  let splits = map (fun dart -> split_list hyp_list dart) darts in
  let subcases = zip (hyp_list :: splits) bbs in
    Lp_split ({split_type = "hex"; split_face = split_face},
              map (build hard_flag) subcases);;


(* Moves all hex faces to std56_flat_free *)
(* This operation can be done on all hypermaps since branching on hexes is not required: *)
(* there are no inequalities for std56_flat_free INTER std6 *)
(* The faces of size 6 of std_faces_not_super are added one by one, starting from *)
(* the last one, with modify_bb and the field flat_free *)
let modify_hex_cases bb =
  let is_hex face = length face = 6 in
  let add_flat_free face acc = modify_bb acc true [("flat_free", face)] [] in
    fold_right add_flat_free (filter is_hex bb.std_faces_not_super) bb;;


(* Builds an lp certificate *)
(* The build counters are reset first. If modify_hex is true then the case is *)
(* transformed by modify_hex_cases. The case is built in the hard mode if it *)
(* has hints. The result contains the string representation of the case and *)
(* the root of the certificate tree *)
let build_certificate modify_hex bb =
  let () = reset_build_counters() in
  let bb = if modify_hex then modify_hex_cases bb else bb in
  let _, hyp_list = convert_to_list bb.string_rep in
  let root = build (bb.hints <> []) (hyp_list, bb) in
  let () = report_build_progress() in
    {hypermap_string = bb.string_rep; root_case = root};;
 



(************************)

(* Builds and saves certificates for all given cases *)
(* Arguments of the resulting function: *)
(*   name_prefix - prefix of the output files <name_prefix>_<k>.dat in !output_dir *)
(*   max         - the buffered certificates are written into the next file as *)
(*                 soon as they contain at least max terminals *)
(*   modify_hex  - passed to build_certificate *)
(*   bbs         - the cases to process *)
(* The certificates are buffered in the reverse order of processing; the rest of *)
(* the buffer is written after the last case *)
let build_and_save_all =
  let counter = ref 0 in
  let total = ref 0 in
  let buf = ref [] in
  let file_counter = ref 0 in
  let buf_terminals = ref 0 in

  let reset_buffer () = (buf := []; buf_terminals := 0) in

  (* Writes a non-empty buffer into the next file and clears it *)
  let save_buf name_prefix =
    match !buf with
      | [] -> ()
      | certificates ->
          let () = incr file_counter in
          let fname = sprintf "%s/%s_%d.dat" !output_dir name_prefix !file_counter in
          let _ = write_lp_certificates fname certificates in
            reset_buffer () in

  (* Builds the certificate of one case and adds it to the buffer *)
  let process name_prefix max_terminals modify_hex bb =
    let () = incr counter in
    let () = report (sprintf "%d/%d" !counter !total) in
    let certificate = build_certificate modify_hex bb in
    let () = buf := certificate :: !buf in
    let () = buf_terminals := !buf_terminals + count_terminals certificate.root_case in
      if !buf_terminals >= max_terminals then
        save_buf name_prefix
      else
        () in

    fun name_prefix max modify_hex bbs ->
      let () = total := length bbs in
      let () = counter := 0 in
      let () = file_counter := 0 in
      let () = reset_buffer () in
      let () = List.iter (process name_prefix max modify_hex) bbs in
        save_buf name_prefix;;


end;;