Update from HH
[Flyspeck/.git] / formal_ineqs / verifier / m_verifier_main.hl
1 (* =========================================================== *)
2 (* Main formal verification functions                          *)
3 (* Author: Alexey Solovyev                                     *)
4 (* Date: 2012-10-27                                            *)
5 (* =========================================================== *)
6
7 needs "verifier/interval_m/verifier.ml";;
8 needs "verifier/m_verifier.hl";;
9 needs "verifier/m_verifier_build.hl";;
10 needs "taylor/m_taylor_arith2.hl";;
11 needs "misc/vars.hl";;
12
13 #load "unix.cma";;
14
15 module M_verifier_main = struct
16
17 open Arith_misc;;
18 open Interval_arith;;
19 open Eval_interval;;
20 open More_float;;
21 open M_verifier;;
22 open M_verifier_build;;
23 open M_taylor;;
24 open M_taylor_arith2;;
25 open Taylor;;
26 open Misc_vars;;
27 open Verifier_options;;
28
29
30
31 (* Parameters *)
32 type verification_parameters =
33 {
34   (* If true, then monotonicity properties can be used *)
35   (* to reduce the dimension of a problem *)
36   allow_derivatives : bool;
37   (* If true, then convexity can be used *)
38   (* to reduce the dimension of a problem *)
39   convex_flag : bool;
40   (* If true, then verification on internal subdomains can be skipped *)
41   (* for a monotone function *)
42   mono_pass_flag : bool;
43   (* If true, then raw interval arithmetic can be used *)
44   (* (without Taylor approximations) *)
45   raw_intervals_flag : bool;
46   (* If true, then an informal procedure is used to determine *)
47   (* the optimal precision for the formal verification *)
48   adaptive_precision : bool;
49   (* This parameter might be used in cases when the certificate search *)
50   (* procedure returns a wrong result due to rounding errors *)
51   (* (this parameter will be eliminated when the search procedure is corrected) *)
52   eps : float;
53 };;
54
55 let default_params =
56 {
57   allow_derivatives = true;
58   convex_flag = true;
59   mono_pass_flag = true;
60   raw_intervals_flag = true;
61   adaptive_precision = true;
62   eps = 0.0;
63 };;
64
65
66 type verification_stats =
67 {
68   total_time : float;
69   formal_verification_time : float;
70   certificate : Verifier.certificate_stats;
71 };;
72
73
74 (********************************)
75
76 (* Adds a constant approximation to the table of known constants *)
77 let add_constant_interval int_th =
78   Eval_interval.add_constant_interval int_th;
79   Informal_eval_interval.add_constant_interval int_th;;
80
81
82 (* Tests if an expression has only given binary and unary operations *)
83 let test_expression bin_ops unary_ops =
84   let rec test =
85     (* Tests if the expression is in the form `a$i` *)
86     let test_vector tm =
87       let var, index = dest_binary "$" tm in
88         dest_var var, dest_small_numeral index in
89       (* Tests if the expression is a valid binary operation *)
90     let test_binary tm =
91         try
92           let lhs, rhs = dest_comb tm in
93           let op, lhs = dest_comb lhs in
94           let c, _ = dest_const op in
95             if mem c bin_ops then (test lhs && test rhs) else false
96         with Failure _ -> false in
97       (* Tests if the expression is a valid unary operation *)
98     let test_unary tm =
99       try
100         let lhs, rhs = dest_comb tm in
101         let c, _ = dest_const lhs in
102           if mem c unary_ops then test rhs else false
103       with Failure _ -> false in
104
105       fun tm ->
106         frees tm = [] or
107           can dest_var tm or
108           can test_vector tm or
109           test_unary tm or
110           test_binary tm in
111     test;;
112
113
114 (* Tests if the given expression is a polynomial expression *)
115 let is_poly =
116   let bin_ops = ["real_add"; "real_mul"; "real_sub"; "real_pow"] in
117   let unary_ops = ["real_neg"] in
118     test_expression bin_ops unary_ops;;
119
120
121 (**********************************)
122
123 (* Creates basic verification functions *)
124 let rec mk_funs =
125   (* add *)
126   let mk_add n (f1, tf1, ti1) (f2, tf2, ti2) =
127     (fun p1 p2 x -> 
128        let a = f1 p1 p2 x and
129            b = f2 p1 p2 x in
130          eval_m_taylor_add2 n p1 p2 a b),
131     Plus (tf1, tf2),
132     (fun p1 p2 x ->
133        let a = ti1 p1 p2 x and
134            b = ti2 p1 p2 x in
135          Informal_taylor.eval_m_taylor_add p1 p2 a b) in
136     (* sub *)
137   let mk_sub n (f1, tf1, ti1) (f2, tf2, ti2) =
138     let neg_one = Interval.mk_interval(-1.0, -1.0) in
139       (fun p1 p2 x -> 
140          let a = f1 p1 p2 x and
141              b = f2 p1 p2 x in
142            eval_m_taylor_sub2 n p1 p2 a b),
143     Plus (tf1, Scale(tf2, neg_one)),
144     (fun p1 p2 x ->
145        let a = ti1 p1 p2 x and
146            b = ti2 p1 p2 x in
147          Informal_taylor.eval_m_taylor_sub p1 p2 a b) in
148     (* mul *)
149   let mk_mul n (f1, tf1, ti1) (f2, tf2, ti2) =
150     (fun p1 p2 x -> 
151        let a = f1 p1 p2 x and
152            b = f2 p1 p2 x in
153          eval_m_taylor_mul2 n p1 p2 a b),
154     Product (tf1, tf2),
155     (fun p1 p2 x ->
156        let a = ti1 p1 p2 x and
157            b = ti2 p1 p2 x in
158          Informal_taylor.eval_m_taylor_mul p1 p2 a b) in
159     (* neg *)
160   let mk_neg n (f1, tf1, ti1) =
161     (fun p1 p2 x -> 
162        let a = f1 p1 p2 x in
163          eval_m_taylor_neg2 n a),
164     Scale (tf1, Interval.mk_interval (-1.0, -1.0)),
165     (fun p1 p2 x ->
166        let a = ti1 p1 p2 x in
167          Informal_taylor.eval_m_taylor_neg a) in
168     (* sqrt *)
169   let mk_sqrt n (f1, tf1, ti1) =
170     (fun p1 p2 x -> 
171        let a = f1 p1 p2 x in
172          eval_m_taylor_sqrt2 n p1 p2 a),
173     Uni_compose (Univariate.usqrt, tf1),
174     (fun p1 p2 x ->
175        let a = ti1 p1 p2 x in
176          Informal_taylor.eval_m_taylor_sqrt p1 p2 a) in
177     (* inv *)
178   let mk_inv n (f1, tf1, ti1) =
179     (fun p1 p2 x -> 
180        let a = f1 p1 p2 x in
181          eval_m_taylor_inv2 n p1 p2 a),
182     Uni_compose (Univariate.uinv, tf1),
183     (fun p1 p2 x ->
184        let a = ti1 p1 p2 x in
185          Informal_taylor.eval_m_taylor_inv p1 p2 a) in
186     (* atn *)
187   let mk_atn n (f1, tf1, ti1) =
188     (fun p1 p2 x -> 
189        let a = f1 p1 p2 x in
190          eval_m_taylor_atn2 n p1 p2 a),
191     Uni_compose (Univariate.uatan, tf1),
192     (fun p1 p2 x ->
193        let a = ti1 p1 p2 x in
194          Informal_taylor.eval_m_taylor_atn p1 p2 a) in
195     (* acs *)
196   let mk_acs n (f1, tf1, ti1) =
197     (fun p1 p2 x -> 
198        let a = f1 p1 p2 x in
199          eval_m_taylor_acs2 n p1 p2 a),
200     Uni_compose (Univariate.uacos, tf1),
201     (fun p1 p2 x ->
202        let a = ti1 p1 p2 x in
203          Informal_taylor.eval_m_taylor_acs p1 p2 a) in
204     (* binary operations *)        
205   let bin_ops = 
206     ["real_add", mk_add;
207      "real_sub", mk_sub;
208      "real_mul", mk_mul] in
209     (* unary operations *)
210   let unary_ops =
211     ["real_neg", mk_neg;
212      "sqrt", mk_sqrt;
213      "atn", mk_atn;
214      "acs", mk_acs;
215      "real_inv", mk_inv] in
216     (* makes a binary operation *)
217   let mk_bin n pp x_var tm =
218     let lhs, rhs = dest_comb tm in
219     let op, lhs = dest_comb lhs in
220     let mk_f = assoc ((fst o dest_const) op) bin_ops in
221     let l_funs = mk_funs n pp (mk_abs(x_var, lhs)) and
222         r_funs = mk_funs n pp (mk_abs(x_var, rhs)) in
223       mk_f n l_funs r_funs in
224     (* makes an unary operation *)
225   let mk_unary n pp x_var tm =
226     let op, rhs = dest_comb tm in
227     let mk_f = assoc ((fst o dest_const) op) unary_ops in
228     let funs = mk_funs n pp (mk_abs(x_var, rhs)) in
229       mk_f n funs in
230     (* the main function *)
231     fun n pp fun_tm ->
232       let x_var, body_tm = dest_abs fun_tm in
233         if is_poly body_tm then
234           let eval_fs, tf, ti = mk_verification_functions_poly pp fun_tm in
235             eval_fs.taylor, tf, ti.Informal_verifier.taylor
236         else
237           try mk_bin n pp x_var body_tm with Failure _ ->
238             mk_unary n pp x_var body_tm;;
239
240            
241 (* Prepares verification functions *)
242 (* fun_tm must be in the form `\x. f x` *)
243 let mk_verification_functions =
244   let dummy_f pp lo hi = failwith "dummy f" and
245       dummy_df i pp lo hi = failwith "dummy df" and
246       dummy_ddf i j pp lo hi = failwith "dummy ddf" and
247       dummy_diff2 lo hi = failwith "dummy diff2" in
248     fun params pp fun_tm ->
249       let x_var, body_tm = dest_abs fun_tm in
250         if is_poly body_tm then
251           mk_verification_functions_poly pp fun_tm
252         else
253           let n = get_dim x_var in
254           let eval_taylor, tf, eval_ti = mk_funs n pp fun_tm in
255           let _ = params := {!params with raw_intervals_flag = false; convex_flag = false} in
256             {taylor = eval_taylor; 
257              f = dummy_f; df = dummy_df; ddf = dummy_ddf; diff2_f = dummy_diff2}, tf,
258       {Informal_verifier.taylor = eval_ti;
259        Informal_verifier.f = dummy_f;
260        Informal_verifier.df = dummy_df;
261        Informal_verifier.ddf = dummy_ddf};;
262
263    
264 (********************************)
265
266 let convert_to_float_list pp lo_flag list_tm =
267   let tms = dest_list list_tm in
268   let i_funs = map build_interval_fun tms in
269   let ints = map (fun f -> eval_interval_fun pp f [] []) i_funs in
270   let extract = (if lo_flag then fst else snd) o dest_pair o rand o concl in
271     mk_list (map extract ints, real_ty);;
272
273
274 (* Creates a theorem |- interval[xx_tm, zz_tm] SUBSET interval[float(xx_tm), float(zz_tm)]
275    and two lists: float(xx_tm) and float(zz_tm) *)
276 let mk_float_domain pp (xx_tm, zz_tm) =
277   let xx_list = dest_list xx_tm and
278       zz_list = dest_list zz_tm in
279   let n = length xx_list in 
280   let get_intervals tms =
281     let i_funs = map build_interval_fun tms in
282       map (fun f -> eval_interval_fun pp f [] []) i_funs in
283   let xx_ints = get_intervals xx_list and
284       zz_ints = get_intervals zz_list in
285   let xx_ineqs = map (CONJUNCT1 o ONCE_REWRITE_RULE[interval_arith]) xx_ints and
286       zz_ineqs = map (CONJUNCT2 o ONCE_REWRITE_RULE[interval_arith]) zz_ints in
287   let a_vals = map (lhand o concl) xx_ineqs and
288       b_vals = map (rand o concl) zz_ineqs in
289   let a_vars = mk_real_vars n "a" and
290       b_vars = mk_real_vars n "b" and
291       c_vars = mk_real_vars n "c" and
292       d_vars = mk_real_vars n "d" in
293   let th0 = (INST (zip xx_list c_vars) o INST (zip zz_list d_vars) o
294                INST (zip a_vals a_vars) o INST (zip b_vals b_vars)) 
295     subset_interval_thms_array.(n) in
296     itlist MY_PROVE_HYP (xx_ineqs @ zz_ineqs) th0, 
297   (mk_list (a_vals, real_ty), mk_list (b_vals, real_ty));;
298
299
300
301 (* Given a term a < b, returns the theorem |- a - b < &0 <=> a < b *)
302 (* Also, deals with > and / *)
303 (* A user can provide additional rewrite theorems *)
304 let mk_standard_ineq =
305   let lemma = REAL_ARITH `a < b <=> a - b < &0` in
306     fun thms tm ->
307       let th0 = (REWRITE_CONV([real_gt; real_div] @ thms) THENC DEPTH_CONV let_CONV) tm in
308       let rhs = rand (concl th0) in
309       let th1 = (ONCE_REWRITE_CONV[lemma] THENC PURE_REWRITE_CONV[REAL_NEG_0; REAL_SUB_RZERO; REAL_SUB_LZERO]) rhs in
310         TRANS th0 th1;;
311
312
313 (* Converts a term in the form `x + y` into the term `\x:real^2. x$1 + x$2` *)
314 let expr_to_vector_fun =
315   let comp_op = `$` in
316     fun expr_tm ->
317       let vars = List.sort Pervasives.compare (frees expr_tm) in
318       let n = length vars in
319       let x_var = mk_var ("x", n_vector_type_array.(if n = 0 then 1 else n)) in
320       let x_tm = mk_icomb (comp_op, x_var) in
321       let vars2 = map (fun i -> mk_comb (x_tm, mk_small_numeral i)) (1--n) in
322         mk_abs (x_var, subst (zip vars2 vars) expr_tm), 
323       (if n = 0 then mk_vector_list [x_var] else mk_vector_list vars);;
324
325
326 (* Given an inequality `P x y`, variable names and the corresponding bounds,
327    yields `(x0 <= x /\ x <= x1) /\ (y0 <= y /\ y <= y1) ==> P x y` *)
328 let mk_ineq ineq_tm names dom_tm =
329   let lo_list = dest_list (fst dom_tm) and
330       hi_list = dest_list (snd dom_tm) in
331   let vars = map (fun name -> mk_var (name, real_ty)) names in
332   let lo_ineqs = map2 (fun tm1 tm2 -> mk_binop le_op_real tm1 tm2) lo_list vars and
333       hi_ineqs = map2 (fun tm1 tm2 -> mk_binop le_op_real tm1 tm2) vars hi_list in
334   let ineqs = map2 (fun tm1 tm2 -> mk_conj (tm1, tm2)) lo_ineqs hi_ineqs in
335   let cond = end_itlist (curry mk_conj) ineqs in
336     mk_imp (cond, ineq_tm);;
337
338
339 (* Reverts the effect of mk_ineq function *)
340 let dest_ineq ineq_tm =
341   if frees ineq_tm = [] then
342     ineq_tm, [], (real_empty_list, real_empty_list)
343   else
344     let tm0 = (rand o concl o PURE_REWRITE_CONV[IMP_IMP; GSYM CONJ_ASSOC]) ineq_tm in
345     let cond, ineq = dest_imp tm0 in
346     let conds = striplist dest_conj cond in
347     let ineqs = ref [] in
348     let decode_ineq tm =
349       let lhs, rhs = dest_binop le_op_real tm in
350       let lo_flag = (frees lhs = []) in
351       let name = (fst o dest_var) (if lo_flag then rhs else lhs) in
352       let val_ref = 
353         try assoc name !ineqs
354         with Failure _ -> 
355           let val_ref = ref (x_var_real, x_var_real) in 
356             ineqs := ((name, val_ref) :: !ineqs); val_ref in
357         val_ref := if lo_flag then (lhs, snd !val_ref) else (fst !val_ref, rhs) in
358     let _ = map (fun tm -> 
359                    (try decode_ineq tm with Failure _ ->
360                       failwith ("Bad variable bound inequality: "^string_of_term tm))) conds in
361     let names, bounds0 = unzip !ineqs in
362     let lo, hi = unzip (map (fun r -> !r) bounds0) in
363     let test_bounds bounds bound_name =
364       let _ = map2 (fun tm name -> if frees tm <> [] then 
365                       failwith (bound_name^" bound is not defined for "^name) else ())
366         bounds names in () in
367     let _ = test_bounds hi "Upper"; test_bounds lo "Lower" in
368       ineq, names, (mk_real_list lo, mk_real_list hi);;
369
370 (*********************************)
371
372 (* Normalizes a verification result *)
373 let normalize_result norm_flag v1 eq_th1 domain_sub_th pass_thm =
374   let th0 = REWRITE_RULE[m_cell_pass] pass_thm in
375   let n = (get_dim o fst o dest_forall o concl) th0 in
376   let th1 = SPEC v1 th0 in
377   let comp_thms = end_itlist CONJ (Array.to_list comp_thms_array.(n)) in
378   let th2 = REWRITE_RULE[comp_thms] th1 in
379   let th3 = (UNDISCH_ALL o REWRITE_RULE[GSYM eq_th1]) th2 in
380   let dom_th = (UNDISCH_ALL o SPEC v1 o REWRITE_RULE[SUBSET]) domain_sub_th in
381   let th4 = (DISCH_ALL o MY_PROVE_HYP dom_th) th3 in
382   let th5 = REWRITE_RULE[IN_INTERVAL; dimindex_array.(n); gen_in_interval n; comp_thms] th4 in
383     if norm_flag then GEN_ALL th5 else th4;;
384
385
386 (* Verifies the given inequality *)
387 (* Returns the final theorem and verification statistics *)
388 let verify_ineq0 params0 norm_flag pp ineq_tm var_names (lo_tm, hi_tm) rewrite_thms =
389   let total_start = Unix.gettimeofday() in
390   let eq_th1 = mk_standard_ineq rewrite_thms ineq_tm in
391   let ineq_tm1 = (lhand o rand o concl) eq_th1 in
392     if frees ineq_tm1 = [] then
393       let i_fun = build_interval_fun ineq_tm1 in
394       let th0 = eval_interval_fun pp i_fun [] [] in
395       let th1 = float_interval_lt0 th0 in
396       let total = Unix.gettimeofday() -. total_start in
397         REWRITE_RULE[GSYM eq_th1] th1, 
398   {total_time = total; formal_verification_time = total; certificate = Verifier.dummy_stats}
399     else
400       let fun_tm, v1 = expr_to_vector_fun ineq_tm1 in
401       let vars = map (fst o dest_var) (dest_vector v1) in
402       let lo_list = dest_list lo_tm and
403           hi_list = dest_list hi_tm in
404       let bounds0 = zip var_names (zip lo_list hi_list) in
405       let bounds = itlist (fun name list -> assoc name bounds0 :: list) vars [] in
406       let xx, zz = unzip bounds in
407       let xx, zz = mk_real_list xx, mk_real_list zz in
408         
409       let domain_sub_th, (xx1, zz1) = mk_float_domain pp (xx, zz) in
410       let n = (get_dim o fst o dest_abs) fun_tm in
411       let xx2 = Informal_taylor.convert_to_float_list pp true xx and
412           zz2 = Informal_taylor.convert_to_float_list pp false zz in
413       let xx_float = map float_of_float_tm (dest_list xx1) and
414           zz_float = map float_of_float_tm (dest_list zz1) in
415         
416       let params = ref params0 in
417       let eval_fs, tf, ti = mk_verification_functions params pp fun_tm in
418         
419       let _ = !info_print_level < 1 or (report0 "Constructing a solution certificate... "; true) in
420       let certificate = Verifier.run_test tf xx_float zz_float false 0.0 
421         !params.allow_derivatives !params.convex_flag !params.mono_pass_flag
422         !params.raw_intervals_flag !params.eps in
423       let stats = Verifier.result_stats certificate in
424       let _ = !info_print_level < 1 or (report0 " done\n"; true) in
425       let _ = !info_print_level < 1 or (Verifier.report_stats stats; true) in
426
427       let c1 = Verifier.transform_result xx_float zz_float certificate in
428       let start, finish, result = 
429         if !params.adaptive_precision then
430           let _ = !info_print_level < 1 or (report0 "Informal verification... "; true) in
431           let c1p = Informal_verifier.m_verify_list pp 1 pp ti c1 xx2 zz2 in
432           let _ = !info_print_level < 1 or (report0 " done\n"; true) in
433
434           let _ = !info_print_level < 1 or (report0 "Formal verification... "; true) in
435           let start = Unix.gettimeofday() in
436           let result = m_p_verify_list n pp eval_fs c1p xx1 zz1 in
437           let finish = Unix.gettimeofday() in
438           let _ = !info_print_level < 1 or (report0 " done\n"; true) in
439             start, finish, result
440         else
441           let _ = !info_print_level < 1 or (report0 "Formal verification... "; true) in
442           let start = Unix.gettimeofday() in
443           let result = m_verify_list n pp eval_fs c1 xx1 zz1 in
444           let finish = Unix.gettimeofday() in
445           let _ = !info_print_level < 1 or (report0 " done\n"; true) in
446             start, finish, result in
447         normalize_result norm_flag v1 eq_th1 domain_sub_th result,
448   {total_time = finish -. total_start; formal_verification_time = finish -. start; certificate = stats};;
449         
450
451 (* A simple verification function which accepts 
452    a list of rewrite theorems which are applied to the inequality
453    before verification *)
454 let verify_ineq_and_rewrite rewrite_thms params pp ineq_tm =
455   let ineq, vars, bounds = dest_ineq ineq_tm in
456     verify_ineq0 params true pp ineq vars bounds rewrite_thms;;
457
458
459 (* The simplest verification function *)
460 let verify_ineq = verify_ineq_and_rewrite [];;
461
462
463 end;;