Update from HH
[Flyspeck/.git] / text_formalization / nonlinear / auto_lib.hl
1 (* ========================================================================= *)
2 (* FLYSPECK - BOOK FORMALIZATION                                             *)
3 (*                                                                           *)
4 (* Chapter: nonlinear inequalities                                           *)
5 (* Author:  Thomas Hales      *)
6 (* Date: 2012-06-02                                                          *)
7 (* ========================================================================= *)
8
9 (* 
10    Generates
11      ocaml function module (Sphere2).
12      C++ Functions from HOL-Light specs.
13      C++ Interval code for inequalities from HOL-Light specs.
14
15    It uses many lemmas from functional_equation.hl
16
17 *)
18
19
20 (* to fix:
21    proj_y1 = sqrt_x1 ,etc.
22
23    to reorganize:
24    move optimize.hl material on C++ generation here.
25    Leave optimize.hl as the preprocessing module.
26 *)
27
28
29 flyspeck_needs "general/flyspeck_lib.hl";;
30 flyspeck_needs  "nonlinear/functional_equation.hl";;
31 flyspeck_needs "nonlinear/optimize.hl";;
32 flyspeck_needs  "nonlinear/parse_ineq.hl";;
33 flyspeck_needs "nonlinear/function_list.hl";;
34
35 module Auto_lib = struct
36
37   let join_comma = Flyspeck_lib.join_comma;;
38   let join_space = Flyspeck_lib.join_space;;
39   let join_lines = Flyspeck_lib.join_lines;;
40   let functions = Function_list.functions();;
41
42 let ocaml_code = 
43   let strip_all b = snd(strip_forall (concl (Nonlinear_lemma.strip_let b))) in
44   let ocam f = Parse_ineq.ocaml_function (strip_all f) in
45   let header = 
46   "(* code automatically generated from Parse_ineq.ocaml_code *)\n\n"^
47    "module Sphere2 = struct\n\n"^
48    "let sqrt = Pervasives.sqrt;;\n\n" ^
49    "let safesqrt = Pervasives.sqrt;;\n\n" ^
50    "let cos = Pervasives.cos;;\n\n" ^
51    "let sin = Pervasives.sin;;\n\n" ^
52    "let log = Pervasives.log;;\n\n" ^
53    "let asn = Pervasives.asin;;\n\n" ^
54    "let atn = Pervasives.atan;;\n\n" ^
55     "let hminus = 1.2317544220903043901;;\n\n" ^
56     "let pow2 x = x ** (2.0);;\n\n" in
57   let tailer =    "end;;\n" in
58     header ^
59    (Flyspeck_lib.join_lines (map ocam functions)) ^ tailer;;
60
61
62 (* Load module Sphere2 *)
63
64 let sphere2_ml = Filename.temp_file "sphere2" ".ml";; 
65
66 Flyspeck_lib.output_filestring sphere2_ml ocaml_code;;
67
68 loadt sphere2_ml;;
69
70
71 let break_functional_lemma thm = 
72   let strip_all b = snd(strip_forall (concl (Nonlinear_lemma.strip_let b))) in
73   let (h,ts) = strip_comb (strip_all thm) in
74   let isdomain h = (fst(dest_const h) = "domain6") in
75   let namebody = if isdomain h then tl ts else ts in
76    (List.nth namebody 0),List.nth namebody 1;;
77
78 let break_term x = break_functional_lemma (ASSUME x);;
79
80 let rec real_arity ty = 
81   let real_ty = `:real` in
82   if (is_vartype ty) then 0
83   else if (ty = real_ty) then 1 else
84       let (a,b) = dest_type ty in
85         if not(a = "fun" && hd b = real_ty && List.length b = 2) then 0
86         else 1 + real_arity (hd(tl b));;
87
88 let mk_testing_string thm = 
89   let native = Optimize.native_fun in
90   let (name,body') = break_functional_lemma thm in
91   let name' = fst (strip_comb name) in
92   let name' = fst (dest_const name') in
93   let cname = Lib.assocd name' native name' in
94   let _ = not(cname="NOT_IMPLEMENTED") or failwith "mk_testing:excluded" in
95   let evalname = 
96     let s = Printf.sprintf "Sphere2.%s 6.36 4.2 4.3 4.4 4.5 4.6" name' in
97     let (b,s')=  Flyspeck_lib.eval_command s in
98     let _ = b or failwith ("evalname: "^name') in
99     let split = Str.split (Str.regexp "[ \n]") in
100     let r = hd (List.rev  (split s')) in
101     let f = float_of_string r in
102     let _ = not(0=Pervasives.compare nan f) or failwith "nan" in
103       f in
104   let mk_string = 
105     Printf.sprintf "  epsValue(\"%s\",%s,%12.12f);" name' cname evalname in
106     mk_string;;
107
108 let mk_n_testing_string thm = 
109   let native = Optimize.native_fun in
110   let (name,body') = break_functional_lemma thm in
111   let name' = fst (strip_comb name) in
112   let (name',ty) = dest_const name' in
113   let rarity = real_arity ty in
114   let nargs = rarity - 7 in 
115   let _ = (rarity >= 8) or failwith "mk_n_testing_string" in 
116   let args = map (fun i -> 0.04 +. (float_of_int i)/. 10.0) (1--nargs) in 
117   let os = join_space (map (Printf.sprintf "%f") args) in
118       (* was interval::interval(".."). Changed 2013/08/14. *)
119   let sargs = join_comma (map (Printf.sprintf "interval(\"%f\")") args) in
120   let cname = Lib.assocd name' native name' in
121   let _ = not(cname="NOT_IMPLEMENTED") or failwith "mk_testing:excluded" in
122   let evalname = 
123     let s = Printf.sprintf "Sphere2.%s %s 6.36 4.2 4.3 4.4 4.5 4.6" name' os in
124     let (b,s')=  Flyspeck_lib.eval_command s in
125     let _ = b or failwith ("evalname: "^name') in
126     let split = Str.split (Str.regexp "[ \n]") in
127     let r = hd (List.rev  (split s')) in
128     let f = float_of_string r in
129     let _ = not(0=Pervasives.compare nan f) or failwith "nan" in
130       f in
131   let mk_string = 
132     Printf.sprintf "  epsValue(\"%s\",%s(%s),%12.12f);" name' cname sargs evalname in
133     mk_string;;
134
135 let all_testing_string = 
136   let can_test = filter (can  mk_testing_string) functions in
137   let can_test_n = filter (can mk_n_testing_string) functions in
138     Flyspeck_lib.join_lines (
139       (map mk_testing_string can_test) @ (map mk_n_testing_string can_test_n));;
140
141 let testing_code = 
142   Printf.sprintf 
143    "\nvoid selfTest() { 
144        cout << \" -- loading test_auto test\" << endl << flush;\n 
145        %s 
146        cout << \" -- done loading test_auto test\" << endl << flush; }\n" 
147     all_testing_string;;
148
149 let not_tested = filter (fun t -> 
150                           not (can mk_testing_string t) &&
151                         not (can mk_n_testing_string t)) functions;;
152
153 (* following is copied and adapted from optimize.hl *)
154
155 let paren s = "("^s^")";;
156 let i_mk = Optimize.i_mk;;
157 let string_of_num' = Optimize.string_of_num';;
158 let dest_decimal = Optimize.dest_decimal;;
159
160 let real_ty = `:real`;;
161
162 let f1_ty = `:real->real`;;
163
164 let f6_ty = `:real->real->real->real->real->real->real` ;;
165
166
167 let f7_ty = `:real->real->real->real->real->real->real->real` ;;
168 let f8_ty = `:real->real->real->real->real->real->real->real->real` ;;
169 let f9_ty = `:real->real->real->real->real->real->real->real->real->real` ;;
170
171
172 real_arity `:real->real->real->real->real` =  5;;
173
174 let f6to6_ty = `:(real->real->real->real->real->real->real) ->
175   (real->real->real->real->real->real->real)`;;
176
177 let infix6_ty = `:(real->real->real->real->real->real->real) ->
178   (real->real->real->real->real->real->real) ->
179   (real->real->real->real->real->real->real)` ;;
180
181 let scalar6_ty = `:(real->real->real->real->real->real->real) ->
182   (real) ->
183   (real->real->real->real->real->real->real)` ;;
184
185 let tyvar_inst = 
186   let realty = `:real` in
187   let u = 
188     setify(List.flatten (map (type_vars_in_term o concl) functions)) in
189     map (fun t-> (realty,t)) u;;
190
191 type_of (inst [(`:real`,`:A`)] `x:A`);;
192
193 let nonnative_functional_terms = 
194   let f = map ((inst tyvar_inst)o concl) functions in
195   let native = map fst Optimize.native_fun in 
196   let name t =   fst(dest_const(fst(strip_comb(fst(break_term t))))) in
197   let m t = not (mem (name t) native) in 
198     filter m f ;;
199
200 let real_types = setify(map (type_of o fst o strip_comb o fst o break_term) 
201                           nonnative_functional_terms);;
202
203 List.length real_types;;
204
205 let terms_with_type ty = 
206   filter (fun t -> ty = type_of(fst(strip_comb(
207     fst(break_term t))))) nonnative_functional_terms;;
208
209 let terms_with_real_arity_ge8 = 
210   filter (fun t -> 8 <= real_arity (type_of(fst(strip_comb(
211     fst(break_term t)))))) nonnative_functional_terms;;
212
213 let f0_terms =  (terms_with_type real_ty);;
214
215 let f0_code = 
216   let f0_template = Printf.sprintf
217     "static const interval %s (\"%20.20f\");" in
218   let f0_mk thm = 
219     let (name,body') = break_functional_lemma thm in
220     let name' = fst (strip_comb name) in
221     let name' = fst (dest_const name') in
222     let s = Printf.sprintf "Sphere2.%s" name' in
223     let (b,s')=  Flyspeck_lib.eval_command s in
224     let _ = b or failwith ("evalname: "^name') in
225     let split = Str.split (Str.regexp "[ \n]") in
226     let r = hd (List.rev  (split s')) in
227     let warn = "// Warning: "^name'^" computed by floating point\n" in
228     let r' =   float_of_string r in
229       warn^(f0_template name' r') in
230     Flyspeck_lib.join_lines (map (f0_mk o ASSUME) f0_terms);;
231
232 let native_fun = Optimize.native_fun;;
233
234 let native_infix = [
235   ("add6","+");
236   ("mul6","*");
237   ("sub6","-");
238   ("div6","/");
239   ("scalar6","*");
240 ];;
241
242 let native_interval = [
243   ("hminus","hminus")
244 ];;
245
246 let f0_name   = 
247   let f0_auto = map (fst o dest_const o fst o break_term) f0_terms in
248   fun s ->
249   if (mem s f0_auto) then s
250       else 
251         try (Lib.assoc s native_interval) with 
252             Failure _ -> failwith (s^" find: real_name") ;;
253
254 let fun_name = 
255   let fun_auto = map (fst o dest_const o fst o strip_comb o fst o break_term)
256     nonnative_functional_terms in
257     fun s->
258       try (Lib.assoc s native_fun) with 
259           Failure _ -> 
260             if (mem s fun_auto) then s else failwith ("fun_name not found: "^s);;
261
262 let is_comma = 
263   let c = "," in
264     fun t ->
265       let (t,_) = strip_comb t in
266         (is_const t && fst (dest_const t) = c);;
267
268 let cpp_string_of_term = 
269   let rec soh t = 
270     if is_var t then fst (dest_var t) else
271       let (f,xs) = strip_comb t in
272       let ifix i = let [a;b] = xs in paren(soh a ^ " " ^ i ^ " " ^ soh b) in
273       let (fv,ty) = 
274         if is_var f
275         then 
276           let (fv,ty) = (dest_var f) in
277           let _ = warn true ("variable function name: "^fv) in
278             (fv,ty)
279         else if (is_const f) then (dest_const f)
280         else
281           failwith ("var/const expected:" ^ string_of_term f) in
282         match fv with
283           | "real_add" -> ifix "+"
284           | "real_mul" -> ifix "*"
285           | "real_div" -> ifix "/"
286           | "real_sub" -> ifix "-"
287           | "," -> ifix ","
288           | "\\/" -> ifix "\\/"
289           | "real_neg" -> let [a] = xs in "(-" ^ soh a ^ ")"
290           | "real_of_num" -> let [a] = xs in i_mk(soh a)  
291           | "NUMERAL" -> let [_] = xs in string_of_num' (dest_numeral t)
292           | "<" -> let [a;b] = xs in paren(soh a ^ " < " ^ soh b)
293           | ">" -> let [a;b] = xs in paren(soh a ^ " > " ^ soh b)
294           | "+" -> let [a;b] = xs in paren(soh a ^ " + " ^ soh b)
295           | "*" -> let [a;b] = xs in paren(soh a ^ " * " ^ soh b)
296           | "DECIMAL" ->  i_mk(string_of_num' (dest_decimal t))
297           | _ -> 
298               if (ty = real_ty) 
299               then paren(f0_name fv)
300               else if (ty= infix6_ty) or (ty=scalar6_ty)
301               then
302                 let op = 
303                   (try Lib.assoc fv native_infix
304                    with Failure _ -> failwith ("parse infix6 "^fv)) in
305                   ifix op
306               else 
307                 (let name = fun_name fv in
308                    if (xs=[]) then paren name else
309                      let p = if (List.length xs = 1 && is_comma (hd xs))
310                      then I else paren in
311                      let args = p (join_comma (map soh xs)) in
312                        paren (name^args)) in
313     fun t -> 
314       try (soh t) 
315       with Failure s -> failwith (s^" .......   "^string_of_term t);;
316
317 (* make functions of 6 variables *)
318
319
320 let f6_code = 
321   let f6_template = Printf.sprintf
322     "static const Function %s = %s;\n" in
323   let f6_terms = (terms_with_type f6_ty) in
324   let f6_auto = 
325     let b = (fst o dest_const o fst o strip_comb o fst o break_term) in
326     let nat = map fst native_fun in
327       filter (fun t -> not (mem (b t) nat)) f6_terms in
328   let f6_mk tt = 
329     let (name1,body') = break_term tt in
330     let name' = fst (strip_comb name1) in
331     let name' = fst (dest_const name') in
332       f6_template name' (cpp_string_of_term body') in
333     join_lines (map f6_mk f6_auto);;
334
335 let fn_code = 
336   let fn_template = Printf.sprintf
337     "static const Function %s(%s) { return (%s); }\n" in
338   let fn_arg_template = Printf.sprintf
339     "const interval& %s" in
340   let fn_terms = terms_with_real_arity_ge8 in  (* (terms_with_type f7_ty)@
341     (terms_with_type f8_ty) @ (terms_with_type f9_ty) in *)
342   let fn_auto = 
343     let b = (fst o dest_const o fst o strip_comb o fst o break_term) in
344     let nat = map fst native_fun in
345       filter (fun t -> not (mem (b t) nat)) fn_terms in
346   let fn_mk tt = 
347     let (name1,body') = break_term tt in
348     let (name',args) =  (strip_comb name1) in
349     let ags = join_comma (map (fn_arg_template o fst o dest_var) args) in
350     let name' = fst (dest_const name') in
351       fn_template name' ags (cpp_string_of_term body') in
352     join_lines (map fn_mk fn_auto);;
353
354 (* make 6 to 6 *)
355
356 let f6to6_template = Printf.sprintf
357   "static const Function %s(const Function& %s) {
358      return %s;
359   }\n";;
360
361 let f6to6_terms = terms_with_type f6to6_ty;;
362
363 let f6to6_mk tt =
364   let (name1,body') = break_term tt in
365   let (name',param) = strip_comb name1 in
366   let _ = List.length param = 1 or 
367     failwith ("one parameter expected "^ string_of_term tt) in
368   let name' = fst (dest_const name') in
369   let param = fst (dest_var (hd param)) in
370     f6to6_template name' param (cpp_string_of_term body');;  
371
372 let f6to6_code = join_lines (map f6to6_mk f6to6_terms);;
373
374 let tmpfile = flyspeck_dir^"/../interval_code/test_auto.cc";;
375
376 let interval_code =
377   f0_code  ^ f6to6_code ^ f6_code ^ fn_code^
378   testing_code;;
379
380 (* based on optimize, but is enhanced with autogenerated code interval_code.  *)
381
382 let mkfile_code t s tags  = 
383   let cpp_header = Optimize.cpp_header() in
384   let cpp_tail = Optimize.cpp_tail() in
385   let isquad = Optimize.is_quad_cluster tags in
386   let p = if isquad then Optimize.mk_cppq_proc else Optimize.mk_cpp_proc in
387   Flyspeck_lib.output_filestring tmpfile
388    (join_lines [cpp_header;interval_code;(p t s tags);cpp_tail]);;
389
390 (*
391 let testid = "9563139965 d";;
392
393 let idq = hd(Ineq.getexact testid);;
394
395 let [(_,tags,post)] = Optimize.preprocess_split_idq idq;;
396
397 mkfile_code false post testid tags;;
398 *)
399
400 (* 
401    This is an enhanced version of what is in optimize.hl.
402    It uses mkfile_code, which adds autogenerated interval_code to what is in Optimize.mkfile_cppq.
403 *)
404
405 let execute_interval ex tags s testineq = 
406   let interval_dir = flyspeck_dir^"/../interval_code" in
407   let _ = mkfile_code testineq s tags in
408   let _ = Optimize.compile_cpp() in 
409   let _ =  (not ex) or  (0=  Sys.command(interval_dir^"/test_auto")) or 
410     failwith "interval execution error" in
411     ();;
412
413 let testsplit_idq ex idq = 
414   let splits = Optimize.preprocess_split_idq idq in
415     map (fun (s,tags,testineq) -> execute_interval ex tags s testineq) splits;;
416
417 let testsplit ex s = testsplit_idq ex (hd (Ineq.getexact s));;
418
419
420 (* *************************************************************************** *)
421 (* Prep.prep_ineqs cases. *)
422 (* no further processing for these. *)
423 (* *************************************************************************** *)
424
425 let test_noprocessing_idq ex idq = 
426   let (s,tags,testineq) = Optimize.idq_fields idq in
427     execute_interval ex tags s testineq;;
428
429 (* let ineqs = !Prep.pre_ineqs *)
430
431 let test_prep ineqs ex s = 
432   let idq = filter (fun idq -> idq.idv = s) ineqs in
433     test_noprocessing_idq ex (hd idq);;
434
435 let test_prep_case_split ineqs ex (s,case,t) =
436   let s' = Printf.sprintf "%s split(%d/%d)" s case t in
437 (*  let _ = Sys.command("sleep 3") in *)
438     test_prep ineqs ex s';;
439
440
441  end;;