flyspeck_needs "nonlinear/prep.hl";;
flyspeck_needs "nonlinear/scripts.hl";;

module Break_case_exec = struct

(* open Sphere2;; *)


(* PRELIMINARIES *)

(* Tuning parameters shared by the whole partition search. *)
let ( MSEC_INC ) = 1000;; (* global parameter: target C++ runtime of one piece, in msec *)
let ERROR_TOLERANCE = 0.3;;  (* try for times within MSEC_INC * (1 +/- ERROR_TOLERANCE ). *)
let verbose = true;; (* when true, progress messages are sent through report *)

(*

The type iargs is used to record a partition of a nonlinear inequality domain into pieces.
The code in this file is used to make the pieces all take approximately MSEC_INC milliseconds
to run.  The type parameter 'a carries the recorded interval arithmetic C++ runtime in milliseconds.

The partition is determined from the data and the original domain determined by (x,z), where
x and z are lists of real numbers giving the lower left and upper right hand corners of the domain.

Iarg_leaf represents a domain with no subpartition.  The domain is all of (x,z).

Iarg_bisect (i,left,right) is a cut exactly in the middle along the ith variable.
The fields left and right give the further partition of the left and right sides.
For example if i=1 and (x,z) = [1.0;1.0;1.0],[3.0;3.0;3.0] then the partition goes
into (xleft,zleft) = [1.0;1.0;1.0],[3.0;2.0;3.0] and (xright,zright)=[1.0;2.0;1.0],[3.0;3.0;3.0].

Iarg_facet ((i,side),frac,msec,residual) breaks the domain in two unequal parts along the ith coordinate 
specified by a fraction 0.0 <= frac <= 1.0.
The fraction is measured from the left if side=false, and from the right otherwise.

For example, if (i,side)=(1,false), frac=0.25 and (x,z)= [x0;x1;x2],[z0;z1;z2], then
(xleft,zleft)= x,[z0;0.75*x1+0.25*z1;z2] and (xright,zright)=[x0;0.75*x1+0.25*z1;x2],z.
The number msec is the C++ runtime for (xleft,zleft), and residual is the recursively defined
partition on (xright,zright).

For example, if we change side=true and keep the rest of the data the same, then
(xleft,zleft)= x,[z0;0.25*x1+0.75*z1;z2] and (xright,zright)=[x0;0.25*x1+0.75*z1;x2],z.
The number msec is now the C++ runtime for (xright,zright), and residual is the recursively defined
partition on (xleft,zleft). So changing the side exchanges right and left all the way along.


*)


(* Partition tree over a domain (x,z).  The payload 'a holds timing data:
   an int number of msec when the tree is built by pass_revised, and a pair
   of ints (recorded time, re-measured time) after recheck_pass.
     Iarg_facet ((i,side),frac,msec,residual): unequal cut along coordinate i.
     Iarg_leaf msec: no further subdivision.
     Iarg_bisect (i,left,right): cut in the middle along coordinate i. *)
type ('a) iargs = Iarg_facet of ((int * bool) * float * 'a * ('a) iargs)
	     | Iarg_leaf of 'a
	     | Iarg_bisect of (int * ('a) iargs * ('a) iargs);;

(* Render a partition tree with int timings as text, one constructor per
   line, in OCaml constructor syntax.  output_case embeds this text in the
   log entries it writes. *)
let rec string_of_iargs iargs =
  match iargs with
    | Iarg_bisect (idx,lhs,rhs) ->
        Printf.sprintf "\n Iarg_bisect (%d,%s,%s)"
          idx (string_of_iargs lhs) (string_of_iargs rhs)
    | Iarg_facet ((i,side),frac,msec,rest) ->
        Printf.sprintf "\n Iarg_facet ((%d,%s),%3.4f,%d,%s)"
          i (string_of_bool side) frac msec (string_of_iargs rest)
    | Iarg_leaf msec ->
        Printf.sprintf "\n Iarg_leaf %d" msec;;

(* Append one add_case entry for inequality id s with partition iargs to the
   log file nonlinear/break_case_log_more.hl.

   The file is opened in append mode without Open_creat, so it must already
   exist; the permission argument 436 (octal 664) is therefore not used.
   The channel is now closed on the error path as well, so a failed write or
   flush no longer leaks the file descriptor; the original exception is
   re-raised unchanged. *)
let output_case s iargs = 
  let sfull = Printf.sprintf "\nadd_case (\"%s\",%s);;\n" s (string_of_iargs iargs) in
  let oc = open_out_gen [Open_append;Open_text] 436 (fullpath "nonlinear/break_case_log_more.hl") in
    try
      (Pervasives.output_string oc sfull; close_out oc)
    with e ->
      (close_out_noerr oc; raise e);;

(* Local shorthand for Printf.sprintf. *)
let sprintf = Printf.sprintf;;

(* Fixed seed, so the random choices made in recursive_find_frac start from
   the same generator state each time this file is loaded. *)
Random.init 0;;

(* First entry of Prep.prep_ineqs whose idv field equals s.
   Fails (through hd) when no entry matches. *)
let getprep s =
  let has_id t = (t.idv = s) in
    hd (filter has_id (!Prep.prep_ineqs));;

(* omit the quad cases for now: *)

(* Ids (idv fields) of all prepared inequalities that are not tagged as
   quad clusters.  Computed once, when this file is loaded. *)
let idvlist =
  let is_nonquad t = not (Optimize.is_quad_cluster t.tags) in
    map (fun t -> t.idv) (filter is_nonquad (!Prep.prep_ineqs));;

(* Local shorthand for List.nth (zero-based list indexing). *)
let nth = List.nth;;

(* Cartesian product: every pair (ai,bj) with ai from a and bj from b,
   grouped by the elements of a in order, then by the elements of b. *)
let cart a b =
  let pair_with ai = map (fun bj -> (ai,bj)) b in
    List.fold_right (fun ai acc -> pair_with ai @ acc) a [];;

(* Largest element of xs under the polymorphic max.
   Fails (through nth) on the empty list. *)
let maxlist xs =
  let seed = nth xs 0 in
    List.fold_right (fun v best -> max v best) xs seed;;

(* Strip leading and trailing whitespace (space, form feed, newline,
   carriage return, tab) from s.

   The bounds are found by scanning indices and a single substring is taken,
   instead of allocating a fresh string for every stripped character.
   Returns s itself when there is nothing to strip, and the empty string
   when s consists only of whitespace. *)
let trim s =
  let white ch =
    match ch with
      | ' ' | '\012' | '\n' | '\r' | '\t' -> true
      | _ -> false in
  let n = String.length s in
  (* index of the first non-white character, or n if there is none *)
  let rec first i = if i < n && white (s.[i]) then first (i+1) else i in
  let lo = first 0 in
  (* one past the last non-white character; never moves below lo *)
  let rec last j = if j > lo && white (s.[j-1]) then last (j-1) else j in
  let hi = last n in
    if lo = 0 && hi = n then s else String.sub s lo (hi - lo);;

(* MSEC_INC as a float, for the floating point timing arithmetic. *)
let msec_inc = float_of_int MSEC_INC;;

(* Mutable slot that eval_float fills, by evaluating a toplevel phrase,
   with a thunk producing the requested float. *)
let float_cache = ref (fun() -> 0.0);;

(* Evaluate the OCaml expression text s to a float.  The text is wrapped in
   a toplevel phrase that stores a thunk in float_cache; the thunk is then
   run.  When the phrase is rejected, the response and s are printed and the
   call fails with "bad input string".
   NOTE(review): the phrase names float_cache without a module qualifier, so
   it presumably relies on this module being opened at the toplevel -- confirm. *)
let eval_float s = 
  let phrase = "float_cache := (fun () -> ("^s^"));;" in
  let (accepted,response) = Flyspeck_lib.eval_command ~silent:false phrase in
  let _ =
    if not accepted then
      (print_string (response^"\n"^s^"\n"); failwith "bad input string") in
    (!float_cache)();;

(* C++ CODE GENERATION. *)

(* Format template for the generated C++ function testRun.  The holes are
   filled, in order, with: the svn string, the inequality id string, a block
   of generated C++ text, the length of the function array, the initializer
   of that array, the allowSharp flag, the onlyCheckDeriv1Negative flag,
   extra option text, and the array length again.  mk_cpp_arg_proc supplies
   the actual arguments. *)
let cpp_template_arg = sprintf "
 const char svn[] = %s;
 const char ineq_id[] = %s;

 int testRun(double x1[6],double z1[6]) // autogenerated code
	{
        // Warning: not rigorous. The rounding is off by epsilon. Use this only for experiments.
	interval tx[6]={interval(x1[0],x1[0]),interval(x1[1],x1[1]),interval(x1[2],x1[2]),
                        interval(x1[3],x1[3]),interval(x1[4],x1[4]),interval(x1[5],x1[5])   };
	interval tz[6]={interval(z1[0],z1[0]),interval(z1[1],z1[1]),interval(z1[2],z1[2]),
                        interval(z1[3],z1[3]),interval(z1[4],z1[4]),interval(z1[5],z1[5])}; 
	domain x = domain::lowerD(tx);
	domain z = domain::upperD(tz);
        domain x0=x;
        domain z0=z;
        %s
        const Function* I[%d] = {%s}; // len ...
        cellOption opt;
        opt.allowSharp = %d; // sharp
        opt.onlyCheckDeriv1Negative = %d; // checkderiv
        %s // other options.
	return  prove::recursiveVerifier(0,x,z,x0,z0,I,%d,opt); // len
	}";;

(* Build the C++ text of testRun for the nonlinear inequality term t with id
   s, by filling cpp_template_arg.  The tags select the allowSharp and
   onlyCheckDeriv1Negative flags and, when Optimize.widthCut reports one, an
   extra line setting opt.widthCutoff. *)
let mk_cpp_arg_proc t s tags = 
  let flag tag = if mem tag tags then 1 else 0 in
  let sharp = flag Sharp in
  let checkderiv = flag Onlycheckderiv1negative in
  let (has_cut,cut) = Optimize.widthCut tags in
  let sWidth =
    if has_cut then sprintf "\topt.widthCutoff = %8.16f;\n" cut else "" in
  let to_cpp = map Optimize.cpp_string_of_term in
  let (terms1,terms2,terms3) = Optimize.dest_nonlin t in
  (* all three components are converted, as before; only the third is used *)
  let (_,_,iis) = (to_cpp terms1,to_cpp terms2,to_cpp terms3) in
  let len = length iis in
  let sq = Optimize.quoted s in
  let svn = Optimize.quoted (Optimize.svn_version()) in
    cpp_template_arg svn sq (Optimize.cpp_template_t "" iis) 
      len (Optimize.cpp_template_Fc "" len) sharp checkderiv sWidth len;;

(* Write the complete C++ source for one inequality to Optimize.tmpfile:
   header, Auto_lib.interval_code, the generated testRun, and the contents
   of arg_tail.txt.  The header and tail are read once, when this
   definition is evaluated, and reused on every call. *)
let mkfile_arg =
  let tail_path = flyspeck_dir^"/../interval_code/arg_tail.txt" in
  let cpp_tail = Optimize.join_lines (Optimize.load_file tail_path) in
  let cpp_header = Optimize.cpp_header() in
    fun t s tags ->
      let body = mk_cpp_arg_proc t s tags in
      let contents =
        Optimize.join_lines [cpp_header;Auto_lib.interval_code;body;cpp_tail] in
        Flyspeck_lib.output_filestring Optimize.tmpfile contents;;

(* Generate and compile the C++ test for testineq, and, when ex is true, run
   the resulting test_auto on the domain with corners xlist and zlist (six
   coordinates each, printed with the %f format).  Fails with
   "interval execution error" when the run returns a nonzero status.
   Compilation happens whether or not ex is set. *)
let execute_args ex tags s testineq xlist zlist =  
  let x = List.nth xlist in
  let z = List.nth zlist in
  let args = sprintf " %f %f %f %f %f %f   %f %f %f %f %f %f"
    (x 0) (x 1) (x 2) (x 3) (x 4) (x 5)     (z 0) (z 1) (z 2) (z 3) (z 4) (z 5) in
  let interval_dir = flyspeck_dir^"/../interval_code" in
  let _ = mkfile_arg testineq s tags in
  let _ = Optimize.compile_cpp() in 
  let _ =
    if ex && Sys.command (interval_dir^"/test_auto"^args) <> 0
    then failwith "interval execution error" in
    ();;

(* Report the inequality id, then return a triple: a function that, given
   the two corner lists, compiles (and with ex set, runs) the inequality,
   followed by the first two components of Optimize.dest_nonlin of the
   inequality term. *)
let process_and_prep_args ex (s,tags,case) = 
  let _ = report ("process and prep args: "^s) in
  let (x,y,_) = Optimize.dest_nonlin case in
    (execute_args ex tags s case, x, y);;

(* Run the already compiled test_auto on the domain with corners xlist and
   zlist, passing timer as a thirteenth argument, and return the number
   printed after msecs= in its output.  When that text cannot be read as an
   int, timer itself is returned.  Fails with "interval execution error"
   when the shell pipeline returns a nonzero status.  The output is copied
   to a temporary file that is created once, when this definition is
   evaluated. *)
let rerun_timer =
  let run_out = Filename.temp_file "run" ".out" in
    fun xlist zlist timer ->
      let x = List.nth xlist in
      let z = List.nth zlist in
      let args = sprintf " %f %f %f %f %f %f   %f %f %f %f %f %f %f"
        (x 0) (x 1) (x 2) (x 3) (x 4) (x 5)     (z 0) (z 1) (z 2) (z 3) (z 4) (z 5) (float_of_int timer) in
      let interval_dir = flyspeck_dir^"/../interval_code" in
      let run_command = interval_dir^"/test_auto"^args^" | tee "^run_out in
      let _ =
        if Sys.command run_command <> 0 then failwith "interval execution error" in
      let extract_command =
        "grep msecs "^run_out^" | sed 's/^.*msecs=//' | sed 's/;.*$//' " in
      let outs = trim (process_to_string extract_command) in
        try int_of_string outs with _ -> timer;;

(* DOMAIN CONSTANTS *)

(* Rewrite the term t with Sphere.h0, the reversed sol0_over_pi lemma and
   assumed decimal values for hminus, pi and sol0, and return the right-hand
   side of the resulting theorem. *)
let scrub_c t = 
  let decimal_values = [
    ASSUME `hminus = #1.2317544220903216`;
    ASSUME `pi = #3.1415926535897932385`;
    ASSUME `sol0 = #0.55128559843253080794`] in
  let rewrites =
    Sphere.h0 :: GSYM Nonlinear_lemma.sol0_over_pi_EQ_const1 :: decimal_values in
    rhs (concl (REWRITE_CONV rewrites t));;

(* The two corner term lists of the inequality with id s, each term passed
   through scrub_c. *)
let get_constants_xy s = 
  let (_,_,case) = Optimize.idq_fields (getprep s) in
  let (x,y,_) = Optimize.dest_nonlin case in
    (map scrub_c x,map scrub_c y);;

(* Both corner term lists of the inequality with id s, as one list. *)
let get_constants s = 
  let (lower,upper) = get_constants_xy s in
    List.append lower upper;;

(* The domain corners of the inequality with id s as two float lists: each
   scrubbed term is printed as OCaml text and evaluated with eval_float. *)
let get_float_domain s = 
  let to_float tm = eval_float (Parse_ineq.ocaml_string_of_term tm) in
  let (x,y) = get_constants_xy s in
    (map to_float x,map to_float y);;

(* they all check out on 2013/7/31
find (fun s -> not(can get_float_domain s)) (snd (chop_list 650 idvlist));;
*)


(* interpolate along edge *)

(* Index of the widest coordinate of the box with corners x and z, where the
   coordinates listed in avoid are treated as having width 0.0.  Ties go to
   the smallest index.

   Fails with "zip" when x and z differ in length and with "nth" when they
   are empty.

   When avoid is nonempty and no coordinate outside it has positive width,
   the call now fails instead of returning an avoided or zero-width index:
   pass_revised retries with a growing avoid list after a failure, and
   handing it back a coordinate it already gave up on lets that retry repeat
   without end.  With an empty avoid list the result is as before, even for
   a box of zero width. *)
let index_max_width (x ,z) avoid = 
  (* widths zi - xi, with avoided coordinates masked to 0.0 *)
  let rec masked_widths i xs zs =
    match (xs,zs) with
      | ([],[]) -> []
      | (xi::xt,zi::zt) ->
          let w = if List.mem i avoid then 0.0 else zi -. xi in
            w :: masked_widths (i+1) xt zt
      | _ -> failwith "zip" in
  let ws = masked_widths 0 x z in
  let wm =
    match ws with
      | [] -> failwith "nth"
      | w0::_ -> List.fold_right max ws w0 in
  let _ =
    if avoid <> [] && not (wm > 0.0) then
      failwith "index_max_width: every coordinate of positive width is in the avoid list" in
  (* position of the first width equal to the maximum *)
  let rec first_at i ls =
    match ls with
      | [] -> failwith "index"
      | w::rest -> if compare w wm = 0 then i else first_at (i+1) rest in
    first_at 0 ws;;

(* The facet on the same coordinate with the side flag flipped. *)
let facet_opp facet = (fst facet, not (snd facet));;

(* Swap the two components of the pair when the facet side flag is true;
   return the pair unchanged otherwise.  The coordinate index is ignored. *)
let invert_side facet (x,z) = 
  match facet with
    | (_,true) -> (z,x)
    | (_,false) -> (x,z);;

(* Cut the box with corners x and z along coordinate i at the fraction frac,
   measured from the x side when side is false and from the z side when it
   is true.  Returns (xm,zm): x and z with coordinate i replaced by the cut
   value, so that the two pieces are x--zm and xm--z.
   Fails when frac is outside [0,1] or when x and z differ in length. *)
let invert_interpolate (i,side) frac (x,z) =
  let _ =
    if not (0.0 <= frac && frac <= 1.0)
    then failwith "invert_interpolate: frac out of range " in
  let _ =
    if List.length x <> List.length z
    then failwith "invert_interpolate: length mismatch" in
  let indices = 0--(List.length x - 1) in
  (* measure from the far corner when side is set *)
  let (near,far) = invert_side (i,side) (x,z) in
  let cut = nth near i *. (1.0-. frac) +. nth far i *. frac in
  let with_cut corner =
    map (fun (v,j) -> if (i=j) then cut else v) (zip corner indices) in
    (* undo the swap so the result is again ordered as (xm,zm) *)
    invert_side (i,side) (with_cut near,with_cut far);;

(* Linear interpolation: the fraction at which the line through (f1,t1) and
   (f2,t2) reaches the target time msec_inc.  Requires
   0 <= f1 < f2 <= 1 and t1 <= msec_inc <= t2, and fails otherwise.
   When t1 and t2 are both equal to msec_inc the quotient is 0/0 and the
   result is nan. *)
let interpolate_frac (f1,t1) (f2,t2) = 
  let fracs_ok = 0.0 <= f1 && f1 < f2 && f2 <= 1.0 in
  let times_ok = t1 <= msec_inc && msec_inc <= t2 in
  let _ =
    if not (fracs_ok && times_ok) then failwith "interpolate: out of range" in
    (msec_inc -. t1) *. (f2 -. f1) /. (t2 -. t1) +. f1;;

(* Cut the box (x1,z1) at frac along facet and time the piece on the facet
   side: x1--zm when the side flag is false, xm--z1 when it is true.
   Returns the cut corners together with the measured msec. *)
let run_frac frac facet (x1,z1) timeout = 
  let (xm,zm) = invert_interpolate facet frac (x1,z1) in
  let m3 =
    match snd facet with
      | false -> rerun_timer x1 zm timeout
      | true -> rerun_timer xm z1 timeout in
    ((xm,zm),m3);;

(* Search for a cut fraction along facet whose piece of the box (x1,z1) runs
   in about msec_inc milliseconds.  (f1,t1) and (f2,t2) are a bracket of
   fractions with their measured times, given as floats.  Returns the
   chosen fraction with its time as an int.  Fails when the bracket becomes
   narrower than minwidth without an acceptable endpoint, or when a
   consistency check on the bracket fails. *)
let rec recursive_find_frac facet x1 z1  (f1,t1) (f2,t2) = 
  (* last case in a batch may have a small t2, so dont use abs_float in the following line *)
  if (t2/. msec_inc -. 1.0) < ERROR_TOLERANCE then (f2,int_of_float t2)
  else if abs_float(t1/. msec_inc -. 1.0) < ERROR_TOLERANCE then (f1,int_of_float t1)
  else       

    (* bracket narrower than minwidth: settle for an endpoint if its time is tolerable *)
    let minwidth = 0.001 in
      if (f2 <= f1 +. minwidth && t2 < 1.99 *. msec_inc) then (f2,int_of_float t2)
      else if (f2 <= f1 +. minwidth && t1 > 0.25 *. msec_inc) then (f1,int_of_float t1)
      else 
	let _ = f1+. minwidth <= f2 or failwith "find_frac width underflow" in

	(* first guess: linear interpolation between the two endpoints *)
	let f3 = interpolate_frac (f1,t1) (f2,t2) in
	(* one time in four, when t2 is far above target, use a random point of the upper half instead *)
	let f3 = if (Random.int 4 = 0 && t2 > 1.9*. msec_inc) then 
	  (f1+.f2)/. 2.0 +. Random.float (f2 -. f1) /. 2.0 else f3 in
	(* one time in three, when t1 is far below target, use a random point of the lower half instead *)
	let f3 = if (Random.int 3 = 0 && t1 < 0.5*. msec_inc) then
	  f1 +. Random.float (f2 -. f1)/. 2.0 else f3 in
	let _ = (f1 <= f3 && f3 <= f2) or failwith "recursive_find_frac: out of range" in
	(* time the piece at f3, with a timeout of twice the target *)
	let ((xm,zm),m3) = run_frac f3 facet (x1,z1) (2 * MSEC_INC) in
	let _ = if verbose then report (sprintf "recursing fracs %3.3f,%3.3f,%3.3f msecs %d,%d,%d" f1 f3 f2 (int_of_float t1) m3 (int_of_float t2)) in

	let t3 = float_of_int m3 in
	  if abs_float(t3/. msec_inc -. 1.0) < ERROR_TOLERANCE then (f3,m3)
	  else 
	    (* keep the half of the bracket that still straddles the target time *)
	    let ((f1',t1'),(f2',t2')) = if (t3 <= msec_inc) then ((f3,t3),(f2,t2)) else ((f1,t1),(f3,t3)) in
	      recursive_find_frac facet x1 z1 (f1',t1') (f2',t2');;

(* DEBUG *)

(* Sum of all the msec values recorded in a partition tree. *)
let rec pass_time iarg =
  match iarg with
    | Iarg_leaf msec -> msec
    | Iarg_facet (_,_,msec,rest) -> msec + pass_time rest
    | Iarg_bisect (_,lhs,rhs) -> pass_time lhs + pass_time rhs;;

(* Walk a partition tree over domain, time every piece again with
   rerun_timer, and return the same tree with each recorded time msec
   replaced by the pair (msec, newly measured time). *)
let rec recheck_pass  timeout domain iarg =
  let (x,z) = domain in
    match iarg with
      | Iarg_leaf msec -> 
          let remeasured = rerun_timer x z timeout in
            Iarg_leaf (msec,remeasured)
      | Iarg_facet(facet,frac,msec,rest) ->
          let (xm,zm) = invert_interpolate facet frac domain in
          (* the residual is the piece on the side opposite to the facet *)
          let residual = if snd facet then (x,zm) else (xm,z) in
          let (_,remeasured) = run_frac frac facet domain timeout in
            Iarg_facet(facet,frac,(msec,remeasured),  recheck_pass timeout residual rest)
      | Iarg_bisect(idx,lhs,rhs) ->
          let (xm,zm) = invert_interpolate (idx,false) 0.5 domain in
            Iarg_bisect(idx,recheck_pass timeout (x,zm) lhs,recheck_pass timeout (xm,z) rhs);;


(* MAIN PROCEDURES *)

(* Previously recorded runtime, in msec, of the inequality with id s.  The
   millisecond table is consulted first; on any failure the seconds table is
   used and its value multiplied by 1000.  Both tables are built once, when
   this definition is evaluated. *)
let initial_msec = 
  let ft_sec = Scripts.finalize Scripts.finished_times in
  let ft_msec = Scripts.finalize Scripts.finished_times_msecs in
  let from_msec_table s = assoc s ft_msec in
  let from_sec_table s = 1000 * (assoc s ft_sec) in
    fun s ->
      try from_msec_table s
      with _ -> from_sec_table s;;

(* Prepare the inequality with id s for timing: look up its recorded
   runtime, compute its float domain, and generate and compile its C++ test
   without running it.  Returns the two domain corners and the recorded
   runtime in msec. *)
let initialize s = 
  let msec = initial_msec s in
  let (x1,z1) = get_float_domain s in
  let fields = Optimize.idq_fields (getprep s) in
  let (compile,_,_) = process_and_prep_args false fields in
  let () = compile x1 z1 in
    (x1,z1,msec);;

(* Build the partition tree for the box (x,z), whose whole-box runtime is
   msec, so that each piece runs in roughly MSEC_INC milliseconds.  Cuts are
   made along the widest coordinate not listed in avoid.  When the fraction
   search fails, the failure is reported and the same box is retried with
   that coordinate added to avoid. *)
let rec pass_revised avoid ((x,z),msec) =
  let left = false in
  let idx = index_max_width (x,z) avoid in
  let fi = float_of_int in
  let timeout = 2 * MSEC_INC in
  let (s0,s1,s2) = (0,1,2) in
  let facet = (idx,left) in
  (* size class of a time m: s0 below (1 - ERROR_TOLERANCE) * MSEC_INC,
     s1 below 1.999 * MSEC_INC, s2 otherwise *)
  let sz m = 
    let t = fi m /. fi MSEC_INC in
      if t < 1.0 -. ERROR_TOLERANCE then s0 else if t < 1.999 then s1 else s2 in
    if msec < timeout  then 
      let _ = if verbose then report "FOUND leaf..." in
      (Iarg_leaf msec)
    else 
      (* time both halves of a cut in the middle and keep the faster one *)
      let (_,m3A) = run_frac 0.5 facet (x,z) timeout in
      let (_,m3B) = run_frac 0.5 (facet_opp facet) (x,z) timeout in
      let (m3min,facet,swap) = if m3A<=m3B then (m3A,facet,false) else (m3B,facet_opp facet,true) in
      let smin = sz m3min in
	(* even the faster half is in class s2: bisect and recurse on both halves *)
	if (smin=s2) then 
	  let (xm,zm) = invert_interpolate (idx,left) 0.5 (x,z) in
	  let partA,partB = ((x,zm),(xm,z)) in
	  let _ = if verbose then report "FOUND bisection..." in
	    Iarg_bisect (idx,pass_revised [] (partA,m3A),pass_revised [] (partB,m3B))
	(* the faster half is in class s1: keep it as a facet, recurse on the other half *)
	else if (smin = s1) then 
	  let (xm,zm) = invert_interpolate facet 0.5 (x,z) in
	  let residual= if (snd facet=left) then (xm,z) else (x,zm) in
	  let m3max = if swap then m3A else m3B in
	  let _ = if verbose then report "FOUND 0.5 facet..." in
	    Iarg_facet(facet,0.5,m3min,pass_revised [] (residual,m3max))
	(* the faster half is in class s0: search for a larger fraction between 0.5 and 1.0 *)
	else (* smin=s0 *) 
	  try (
	    let (fracC,mC) = recursive_find_frac facet x z (0.5,fi m3min) (1.0,fi msec) in
	    let (xm,zm) = invert_interpolate facet fracC (x,z) in
	    let residual= if (snd facet=left) then (xm,z) else (x,zm) in
	    let mD = rerun_timer (fst residual) (snd residual) timeout in
	    let _ = if verbose then report (sprintf "FOUND %3.3f,%d facet..." fracC mC) in
	      Iarg_facet (facet,fracC,mC,pass_revised [] (residual,mD)))
          with Failure fail ->
	    let _ = report ("failure "^fail) in
	      pass_revised (idx::avoid) ((x,z),msec);;

(* Compute the partition tree for the inequality with id s, starting from
   its full domain and recorded runtime.  When save is true the tree is also
   appended to the log through output_case.  Returns the tree. *)
let run_one save s = 
  let (x1,z1,msec1) = initialize s in
  let iarg = pass_revised [] ((x1,z1),msec1) in
  let () = if save then output_case s iarg in
    iarg;;

(* Run run_one on every id in idvlist whose recorded runtime exceeds 2000
   msec, ordered by that runtime with the sort predicate t1 < t2.  Returns
   the list of (id, tree) pairs.  A missing recorded time fails with the id;
   a Failure during a run is reported with its id and raised again. *)
let run_all_over2sec save = 
  let with_time s = try (s,initial_msec s) with _ -> failwith s in
  let is_slow (_,t) = t > 2000 in
  let faster_first (_,t1) (_,t2) = t1 < t2 in
  let preptimes = map with_time idvlist in
  let over2 = map fst (sort faster_first (filter is_slow preptimes)) in
  let run s =
    try (s,run_one save s)
    with Failure msg -> (report s; failwith msg) in
    map run over2;;



end;;