(* ========================================================================= *)
(* FLYSPECK - BOOK FORMALIZATION                                             *)
(*                                                                           *)
(* Chapter: nonlinear inequalities                                           *)
(* Author:  Thomas Hales      *)
(* Date: 2012-06-02                                                          *)
(* ========================================================================= *)

(* 
   This file generates:
     - the OCaml module Sphere2 (floating-point versions of the functions);
     - C++ function definitions from the HOL Light specifications;
     - C++ interval code for inequalities from the HOL Light specifications.

   It uses many lemmas from functional_equation.hl

*)


(* to fix:
   proj_y1 = sqrt_x1 ,etc.

   to reorganize:
   move optimize.hl material on C++ generation here.
   Leave optimize.hl as the preprocessing module.
*)


flyspeck_needs "general/flyspeck_lib.hl";;
flyspeck_needs  "nonlinear/functional_equation.hl";;
flyspeck_needs "nonlinear/optimize.hl";;
flyspeck_needs  "nonlinear/parse_ineq.hl";;
flyspeck_needs "nonlinear/function_list.hl";;

module Auto_lib = struct

  (* Local short names for the string-joining helpers of Flyspeck_lib. *)
  let join_comma = Flyspeck_lib.join_comma;;
  let join_space = Flyspeck_lib.join_space;;
  let join_lines = Flyspeck_lib.join_lines;;
  (* The list returned by Function_list.functions, computed once when this
     module is loaded.  Every generator below iterates over it; its elements
     are used as theorems (concl is applied to them). *)
  let functions = Function_list.functions();;

(* Source text of the generated module Sphere2: a fixed preamble that binds
   the elementary floating-point functions, then one definition per entry of
   functions (rendered by Parse_ineq.ocaml_function), then the closing end. *)
let ocaml_code = 
  (* statement of a lemma with let-terms expanded and outer quantifiers removed *)
  let body_of th = snd (strip_forall (concl (Nonlinear_lemma.strip_let th))) in
  let definition_of th = Parse_ineq.ocaml_function (body_of th) in
  let preamble = String.concat "" [
    "(* code automatically generated from Parse_ineq.ocaml_code *)\n\n";
    "module Sphere2 = struct\n\n";
    "let sqrt = Pervasives.sqrt;;\n\n";
    "let safesqrt = Pervasives.sqrt;;\n\n";
    "let cos = Pervasives.cos;;\n\n";
    "let sin = Pervasives.sin;;\n\n";
    "let log = Pervasives.log;;\n\n";
    "let asn = Pervasives.asin;;\n\n";
    "let atn = Pervasives.atan;;\n\n";
    "let hminus = 1.2317544220903043901;;\n\n";
    "let pow2 x = x ** (2.0);;\n\n";
  ] in
  let definitions = Flyspeck_lib.join_lines (map definition_of functions) in
    preamble ^ definitions ^ "end;;\n";;


(* Load module Sphere2:
   write the generated source ocaml_code to a fresh temporary .ml file and
   load that file with loadt, so that Sphere2 can be referenced by the
   eval_command calls made further down.  The order of the three statements
   matters.  The temporary file is not removed here. *)

let sphere2_ml = Filename.temp_file "sphere2" ".ml";; 

Flyspeck_lib.output_filestring sphere2_ml ocaml_code;;

loadt sphere2_ml;;


(* Split the statement of a functional lemma into a pair of terms.
   The conclusion (after Nonlinear_lemma.strip_let and removal of outer
   universal quantifiers) is viewed as a constant applied to arguments; when
   that constant is domain6 its first argument is skipped.  The first two
   remaining arguments are returned.  Fails if the head is not a constant or
   if fewer than two arguments remain. *)
let break_functional_lemma thm = 
  let statement = snd (strip_forall (concl (Nonlinear_lemma.strip_let thm))) in
  let (head,operands) = strip_comb statement in
  let operands = 
    if fst (dest_const head) = "domain6" then tl operands else operands in
    match operands with
      | first :: second :: _ -> (first,second)
      | _ -> failwith "nth";;

(* Term version of break_functional_lemma: the term is first turned into a
   theorem with ASSUME. *)
let break_term = break_functional_lemma o ASSUME;;

(* Count the occurrences of real along the spine of a curried function type:
   real gives 1, real->real gives 2, and so on (the result type is counted).
   A type variable gives 0, and the count stops with 0 as soon as a link of
   the chain is not a function type with domain real. *)
let rec real_arity ty = 
  let real_ty = `:real` in
    if is_vartype ty then 0
    else if ty = real_ty then 1
    else
      match dest_type ty with
        | ("fun",[dom;cod]) when dom = real_ty -> 1 + real_arity cod
        | _ -> 0;;

(* Build one epsValue(...) line of the C++ self test for a function of six
   variables.  The reference value is obtained by evaluating the Sphere2
   version of the function at the fixed sample point 6.36 4.2 4.3 4.4 4.5 4.6
   through Flyspeck_lib.eval_command.
   Fails when the function is mapped to NOT_IMPLEMENTED in
   Optimize.native_fun, when the evaluation does not succeed, when the last
   token of the response is not a float, or when the value is nan. *)
let mk_testing_string thm = 
  let head_term = fst (strip_comb (fst (break_functional_lemma thm))) in
  let hol_name = fst (dest_const head_term) in
  (* C++ name: the native translation if there is one, else the HOL name *)
  let cpp_name = Lib.assocd hol_name Optimize.native_fun hol_name in
  let _ = 
    if cpp_name = "NOT_IMPLEMENTED" then failwith "mk_testing:excluded" in
  let sample_value = 
    let command = 
      Printf.sprintf "Sphere2.%s 6.36 4.2 4.3 4.4 4.5 4.6" hol_name in
    let (ok,response) = Flyspeck_lib.eval_command command in
    let _ = if not ok then failwith ("evalname: "^hol_name) in
    (* the numerical result is the last blank-separated token *)
    let tokens = Str.split (Str.regexp "[ \n]") response in
    let value = float_of_string (hd (List.rev tokens)) in
    let _ = if Pervasives.compare nan value = 0 then failwith "nan" in
      value in
    Printf.sprintf "  epsValue(\"%s\",%s,%12.12f);" 
      hol_name cpp_name sample_value;;

(* Variant of mk_testing_string for functions that take extra leading real
   parameters before the six variables (real_arity at least 8).
   The extra parameters get the sample values 0.14, 0.24, ... and are passed
   to the C++ function as interval literals; the six variables are evaluated
   at 6.36 4.2 4.3 4.4 4.5 4.6.  Fails under the same conditions as
   mk_testing_string, and also when the arity is below 8. *)
let mk_n_testing_string thm = 
  let head_term = fst (strip_comb (fst (break_functional_lemma thm))) in
  let (hol_name,hol_ty) = dest_const head_term in
  let arity = real_arity hol_ty in
  let _ = if arity < 8 then failwith "mk_n_testing_string" in 
  let extra_values = 
    map (fun i -> 0.04 +. (float_of_int i)/. 10.0) (1--(arity - 7)) in 
  let ocaml_args = join_space (map (Printf.sprintf "%f") extra_values) in
  (* the C++ literal was interval::interval(..) until 2013/08/14 *)
  let cpp_args = 
    join_comma (map (Printf.sprintf "interval(\"%f\")") extra_values) in
  let cpp_name = Lib.assocd hol_name Optimize.native_fun hol_name in
  let _ = 
    if cpp_name = "NOT_IMPLEMENTED" then failwith "mk_testing:excluded" in
  let sample_value = 
    let command = 
      Printf.sprintf "Sphere2.%s %s 6.36 4.2 4.3 4.4 4.5 4.6" 
        hol_name ocaml_args in
    let (ok,response) = Flyspeck_lib.eval_command command in
    let _ = if not ok then failwith ("evalname: "^hol_name) in
    (* the numerical result is the last blank-separated token *)
    let tokens = Str.split (Str.regexp "[ \n]") response in
    let value = float_of_string (hd (List.rev tokens)) in
    let _ = if Pervasives.compare nan value = 0 then failwith "nan" in
      value in
    Printf.sprintf "  epsValue(\"%s\",%s(%s),%12.12f);" 
      hol_name cpp_name cpp_args sample_value;;

(* All self-test lines: first those produced by mk_testing_string, then those
   produced by mk_n_testing_string, in the order of functions.  Entries for
   which a generator fails (raises Failure) are skipped for that generator.
   mapfilter applies each generator once per entry; the earlier form ran it
   once under can and a second time under map, and every run goes through
   Flyspeck_lib.eval_command. *)
let all_testing_string = 
  let six_variable_tests = mapfilter mk_testing_string functions in
  let parametrized_tests = mapfilter mk_n_testing_string functions in
    Flyspeck_lib.join_lines (six_variable_tests @ parametrized_tests);;

(* C++ text of the selfTest routine: the lines of all_testing_string are
   substituted at the percent-s position, between two progress messages
   written to cout.  The line breaks and spaces inside the literal are part
   of the emitted text. *)
let testing_code = 
  Printf.sprintf 
   "\nvoid selfTest() { 
       cout << \" -- loading test_auto test\" << endl << flush;\n 
       %s 
       cout << \" -- done loading test_auto test\" << endl << flush; }\n" 
    all_testing_string;;

(* Entries of functions for which neither test-line generator succeeds. *)
let not_tested = 
  let testable t = can mk_testing_string t or can mk_n_testing_string t in
    filter (fun t -> not (testable t)) functions;;

(* following is copied and adapted from optimize.hl *)

(* Wrap a string in round brackets. *)
let paren s = Printf.sprintf "(%s)" s;;
(* Helpers reused from Optimize.  In cpp_string_of_term below, i_mk wraps the
   text produced for real_of_num and DECIMAL terms, string_of_num' renders
   the number returned by dest_numeral or dest_decimal, and dest_decimal is
   applied to DECIMAL terms. *)
let i_mk = Optimize.i_mk;;
let string_of_num' = Optimize.string_of_num';;
let dest_decimal = Optimize.dest_decimal;;

(* HOL types used to classify the functions by the type of their head
   constant. *)

(* the type of real constants *)
let real_ty = `:real`;;

(* real functions of one variable *)
let f1_ty = `:real->real`;;

(* real functions of six variables *)
let f6_ty = `:real->real->real->real->real->real->real` ;;


(* real functions of seven, eight and nine variables *)
let f7_ty = `:real->real->real->real->real->real->real->real` ;;
let f8_ty = `:real->real->real->real->real->real->real->real->real` ;;
let f9_ty = `:real->real->real->real->real->real->real->real->real->real` ;;


(* sanity check of real_arity; the boolean result is not bound to a name *)
real_arity `:real->real->real->real->real` =  5;;

(* operators taking one six-variable function to another *)
let f6to6_ty = `:(real->real->real->real->real->real->real) ->
  (real->real->real->real->real->real->real)`;;

(* binary operators on six-variable functions *)
let infix6_ty = `:(real->real->real->real->real->real->real) ->
  (real->real->real->real->real->real->real) ->
  (real->real->real->real->real->real->real)` ;;

(* operators combining a six-variable function with a real scalar *)
let scalar6_ty = `:(real->real->real->real->real->real->real) ->
  (real) ->
  (real->real->real->real->real->real->real)` ;;

(* Type instantiation sending every type variable that occurs in the
   conclusion of some entry of functions to the type real. *)
let tyvar_inst = 
  let realty = `:real` in
  let tyvars_of th = type_vars_in_term (concl th) in
  let all_tyvars = setify (List.flatten (map tyvars_of functions)) in
    map (fun tv -> (realty,tv)) all_tyvars;;

(* scratch check of inst on a single type variable; result not bound *)
type_of (inst [(`:real`,`:A`)] `x:A`);;

(* Conclusions of the entries of functions, with type variables instantiated
   to real, restricted to those whose head constant has no entry in
   Optimize.native_fun. *)
let nonnative_functional_terms = 
  let statements = map (fun th -> inst tyvar_inst (concl th)) functions in
  let native_names = map fst Optimize.native_fun in 
  let head_name tm = 
    fst (dest_const (fst (strip_comb (fst (break_term tm))))) in
  let is_nonnative tm = not (mem (head_name tm) native_names) in 
    filter is_nonnative statements;;

(* The distinct types of the head constants of the non-native terms. *)
let real_types = 
  let head_type tm = type_of (fst (strip_comb (fst (break_term tm)))) in
    setify (map head_type nonnative_functional_terms);;

(* number of distinct types; result not bound *)
List.length real_types;;

(* The non-native terms whose head constant has exactly the type ty. *)
let terms_with_type ty = 
  let head_type tm = type_of (fst (strip_comb (fst (break_term tm)))) in
    filter (fun tm -> head_type tm = ty) nonnative_functional_terms;;

(* The non-native terms whose head constant has real_arity at least 8. *)
let terms_with_real_arity_ge8 = 
  let head_arity tm = 
    real_arity (type_of (fst (strip_comb (fst (break_term tm))))) in
    filter (fun tm -> head_arity tm >= 8) nonnative_functional_terms;;

(* The non-native terms that define real constants. *)
let f0_terms = terms_with_type real_ty;;

(* C++ declarations of the real constants of f0_terms.  Each value is found
   by evaluating the Sphere2 constant of the same name in floating point
   through Flyspeck_lib.eval_command, and each declaration is preceded by a
   C++ comment warning about that.  Fails if an evaluation does not succeed
   or if the last token of the response is not a float. *)
let f0_code = 
  let declare thm = 
    let head_term = fst (strip_comb (fst (break_functional_lemma thm))) in
    let cname = fst (dest_const head_term) in
    let (ok,response) = 
      Flyspeck_lib.eval_command (Printf.sprintf "Sphere2.%s" cname) in
    let _ = if not ok then failwith ("evalname: "^cname) in
    (* the numerical result is the last blank-separated token *)
    let tokens = Str.split (Str.regexp "[ \n]") response in
    let last_token = hd (List.rev tokens) in
    let warning = "// Warning: "^cname^" computed by floating point\n" in
    let value = float_of_string last_token in
    let declaration = 
      Printf.sprintf "static const interval %s (\"%20.20f\");" cname value in
      warning ^ declaration in
    Flyspeck_lib.join_lines (map (fun tm -> declare (ASSUME tm)) f0_terms);;

(* Association list from Optimize; it is used in this file to look up the
   C++ name attached to a HOL constant name. *)
let native_fun = Optimize.native_fun;;

(* HOL operators on six-variable functions and the C++ infix operator that
   cpp_string_of_term prints for them. *)
let native_infix = [
  ("add6","+");
  ("mul6","*");
  ("sub6","-");
  ("div6","/");
  ("scalar6","*");
];;

(* Real constants accepted by f0_name in addition to the ones in f0_terms,
   paired with the name f0_name returns for them. *)
let native_interval = [
  ("hminus","hminus")
];;

(* C++ name of a real constant: the name itself when it is one of the
   constants of f0_terms, otherwise its entry in native_interval.
   Fails when neither applies. *)
let f0_name = 
  let generated = 
    map (fun tm -> fst (dest_const (fst (break_term tm)))) f0_terms in
  let lookup s = 
    if mem s generated then s
    else
      try Lib.assoc s native_interval
      with Failure _ -> failwith (s^" find: real_name") in
    lookup;;

(* C++ name of a function: its entry in native_fun when there is one,
   otherwise the name itself provided it heads one of the non-native terms.
   Fails when neither applies. *)
let fun_name = 
  let head_name tm = 
    fst (dest_const (fst (strip_comb (fst (break_term tm))))) in
  let generated = map head_name nonnative_functional_terms in
  let lookup s = 
    try Lib.assoc s native_fun
    with Failure _ -> 
      if mem s generated then s 
      else failwith ("fun_name not found: "^s) in
    lookup;;

(* True when the head of the term is the pairing constant named by a comma. *)
let is_comma t = 
  let head = fst (strip_comb t) in
    is_const head && fst (dest_const head) = ",";;

(* Render a HOL term as C++ source text.
   - a variable is printed as its name;
   - arithmetic, comparison, pairing and disjunction constants are printed
     infix inside round brackets;
   - numerals and decimals go through string_of_num', and real_of_num and
     DECIMAL are wrapped with i_mk;
   - a real constant is printed through f0_name;
   - a constant of type infix6_ty or scalar6_ty is printed infix with the
     operator found in native_infix;
   - anything else is printed as a call of fun_name applied to the printed
     arguments.
   A variable in function position is accepted with a warning.
   Every error is a Failure whose message ends with the printed input term.
   Arguments are taken apart with explicit arity checks that raise Failure:
   the partial let-patterns used before raised Match_failure on an
   unexpected number of arguments, which bypassed the handler at the bottom
   and lost the term from the report. *)
let cpp_string_of_term = 
  let one_arg fv xs = 
    match xs with
      | [a] -> a
      | _ -> failwith ("cpp_string_of_term: one argument expected for "^fv) in
  let two_args fv xs = 
    match xs with
      | [a;b] -> (a,b)
      | _ -> failwith ("cpp_string_of_term: two arguments expected for "^fv) in
  let rec soh t = 
    if is_var t then fst (dest_var t) else
      let (f,xs) = strip_comb t in
      let (fv,ty) = 
        if is_var f
        then 
          let (fv,ty) = (dest_var f) in
          let _ = warn true ("variable function name: "^fv) in
            (fv,ty)
        else if (is_const f) then (dest_const f)
        else
          failwith ("var/const expected:" ^ string_of_term f) in
      (* bracketed infix form of a binary application *)
      let ifix i = 
        let (a,b) = two_args fv xs in 
          paren(soh a ^ " " ^ i ^ " " ^ soh b) in
        match fv with
          | "real_add" -> ifix "+"
          | "real_mul" -> ifix "*"
          | "real_div" -> ifix "/"
          | "real_sub" -> ifix "-"
          | "," -> ifix ","
          | "\\/" -> ifix "\\/"
          | "real_neg" -> let a = one_arg fv xs in "(-" ^ soh a ^ ")"
          | "real_of_num" -> let a = one_arg fv xs in i_mk(soh a)  
          | "NUMERAL" -> 
              let _ = one_arg fv xs in string_of_num' (dest_numeral t)
          | "<" -> ifix "<"
          | ">" -> ifix ">"
          | "+" -> ifix "+"
          | "*" -> ifix "*"
          | "DECIMAL" ->  i_mk(string_of_num' (dest_decimal t))
          | _ -> 
              if (ty = real_ty) 
              then paren(f0_name fv)
              else if (ty= infix6_ty) or (ty=scalar6_ty)
              then
                let op = 
                  (try Lib.assoc fv native_infix
                   with Failure _ -> failwith ("parse infix6 "^fv)) in
                  ifix op
              else 
                (let name = fun_name fv in
                   if (xs=[]) then paren name else
                     (* a single tuple argument already prints with its
                        own brackets *)
                     let p = if (List.length xs = 1 && is_comma (hd xs))
                     then I else paren in
                     let args = p (join_comma (map soh xs)) in
                       paren (name^args)) in
    fun t -> 
      try (soh t) 
      with Failure s -> failwith (s^" .......   "^string_of_term t);;

(* make functions of 6 variables *)


(* C++ definitions of the non-native functions of six variables: one
   static const Function per term of type f6_ty, whose right-hand side is the
   C++ rendering of the second component returned by break_term. *)
let f6_code = 
  let head_name tm = 
    fst (dest_const (fst (strip_comb (fst (break_term tm))))) in
  let native_names = map fst native_fun in
  let generated = 
    filter (fun tm -> not (mem (head_name tm) native_names)) 
      (terms_with_type f6_ty) in
  let emit tm = 
    let (target,defn) = break_term tm in
    let cname = fst (dest_const (fst (strip_comb target))) in
      Printf.sprintf "static const Function %s = %s;\n" 
        cname (cpp_string_of_term defn) in
    join_lines (map emit generated);;

(* C++ definitions of the non-native functions with extra real parameters
   (the terms of terms_with_real_arity_ge8, which replaced an earlier
   selection by the types f7_ty, f8_ty and f9_ty).  Each one becomes a C++
   function whose formal parameters are the variables applied to the head
   constant, passed as const interval references, and whose body returns
   the C++ rendering of the second component returned by break_term. *)
let fn_code = 
  let head_name tm = 
    fst (dest_const (fst (strip_comb (fst (break_term tm))))) in
  let native_names = map fst native_fun in
  let generated = 
    filter (fun tm -> not (mem (head_name tm) native_names)) 
      terms_with_real_arity_ge8 in
  let formal v = Printf.sprintf "const interval& %s" (fst (dest_var v)) in
  let emit tm = 
    let (target,defn) = break_term tm in
    let (head,params) = strip_comb target in
    let formals = join_comma (map formal params) in
    let cname = fst (dest_const head) in
      Printf.sprintf "static const Function %s(%s) { return (%s); }\n" 
        cname formals (cpp_string_of_term defn) in
    join_lines (map emit generated);;

(* make 6 to 6 *)

(* Formatter for a C++ function that takes one Function by const reference.
   Its three string arguments are, in order, the function name, the
   parameter name and the returned expression.  The line breaks inside the
   literal are part of the emitted text. *)
let f6to6_template = Printf.sprintf
  "static const Function %s(const Function& %s) {
     return %s;
  }\n";;

(* the non-native terms whose head constant has type f6to6_ty *)
let f6to6_terms = terms_with_type f6to6_ty;;

(* C++ definition for one term of f6to6_terms.  The first component returned
   by break_term must be a constant applied to exactly one variable, which
   becomes the parameter of the C++ function; the second component is
   rendered as the returned expression.  Fails otherwise. *)
let f6to6_mk tt =
  let (target,defn) = break_term tt in
    match strip_comb target with
      | (head,[arg]) ->
          let cname = fst (dest_const head) in
          let pname = fst (dest_var arg) in
            f6to6_template cname pname (cpp_string_of_term defn)
      | _ -> failwith ("one parameter expected "^ string_of_term tt);;  

(* C++ definitions for all terms of f6to6_terms. *)
let f6to6_code = 
  let definitions = map f6to6_mk f6to6_terms in
    join_lines definitions;;

(* Path of the generated C++ file, relative to flyspeck_dir. *)
let tmpfile = 
  String.concat "" [flyspeck_dir; "/../interval_code/test_auto.cc"];;

(* All automatically generated C++ text, in the order: real constants,
   function-to-function operators, six-variable functions, parametrized
   functions, self test. *)
let interval_code =
  String.concat "" 
    [f0_code; f6to6_code; f6_code; fn_code; testing_code];;

(* based on optimize, but is enhanced with autogenerated code interval_code.  *)

(* Write the complete C++ test file to tmpfile: the header from Optimize,
   the generated interval_code, the procedure built from t, s and tags, and
   the tail from Optimize, joined as lines.  The procedure is built by
   Optimize.mk_cppq_proc when Optimize.is_quad_cluster holds of tags and by
   Optimize.mk_cpp_proc otherwise. *)
let mkfile_code t s tags  = 
  let prologue = Optimize.cpp_header() in
  let epilogue = Optimize.cpp_tail() in
  let mk_proc = 
    if Optimize.is_quad_cluster tags 
    then Optimize.mk_cppq_proc 
    else Optimize.mk_cpp_proc in
  let procedure = mk_proc t s tags in
  let contents = join_lines [prologue;interval_code;procedure;epilogue] in
    Flyspeck_lib.output_filestring tmpfile contents;;

(*
let testid = "9563139965 d";;

let idq = hd(Ineq.getexact testid);;

let [(_,tags,post)] = Optimize.preprocess_split_idq idq;;

mkfile_code false post testid tags;;
*)

(* 
   This is an enhanced version of what is in optimize.hl.
   It uses mkfile_code, which adds autogenerated interval_code to what is in Optimize.mkfile_cppq.
*)

(* Generate the C++ file for one inequality with mkfile_code, compile it with
   Optimize.compile_cpp and, when ex is true, run the resulting test_auto
   program.  Fails when the program is run and its exit status is not 0. *)
let execute_interval ex tags s testineq = 
  let interval_dir = flyspeck_dir^"/../interval_code" in
  let _ = mkfile_code testineq s tags in
  let _ = Optimize.compile_cpp() in 
    if ex && Sys.command (interval_dir^"/test_auto") <> 0
    then failwith "interval execution error"
    else ();;

(* Split an inequality with Optimize.preprocess_split_idq and run
   execute_interval on every piece; returns the list of results. *)
let testsplit_idq ex idq = 
  let run_piece (s,tags,testineq) = execute_interval ex tags s testineq in
    map run_piece (Optimize.preprocess_split_idq idq);;

(* testsplit_idq applied to the first inequality that Ineq.getexact returns
   for the identifier s. *)
let testsplit ex s = 
  let idq = hd (Ineq.getexact s) in
    testsplit_idq ex idq;;


(* *************************************************************************** *)
(* Prep.prep_ineqs cases. *)
(* no further processing for these. *)
(* *************************************************************************** *)

(* Run execute_interval on an inequality as it stands, using the fields
   returned by Optimize.idq_fields, without any preprocessing or splitting. *)
let test_noprocessing_idq ex idq = 
  match Optimize.idq_fields idq with
    | (s,tags,testineq) -> execute_interval ex tags s testineq;;

(* let ineqs = !Prep.pre_ineqs *)

(* Run test_noprocessing_idq on the first element of ineqs whose idv field
   equals s.  Fails with a message naming s when there is no such element
   (the earlier form failed inside hd, without saying which identifier was
   missing). *)
let test_prep ineqs ex s = 
  match filter (fun idq -> idq.idv = s) ineqs with
    | [] -> failwith ("test_prep: no inequality with id "^s)
    | idq :: _ -> test_noprocessing_idq ex idq;;

(* test_prep for one case of a split inequality.  The identifier looked up
   is s followed by a space and split(case/t). *)
let test_prep_case_split ineqs ex (s,case,t) =
  let split_id = Printf.sprintf "%s split(%d/%d)" s case t in
    test_prep ineqs ex split_id;;


 end;;