Update from HH
[Flyspeck/.git] / formal_ineqs / arith / eval_interval.hl
1 (* =========================================================== *)
2 (* Formal interval evaluation of arithmetic expressions        *)
3 (* Author: Alexey Solovyev                                     *)
4 (* Date: 2012-10-27                                            *)
5 (* =========================================================== *)
6
7 needs "arith/more_float.hl";;
8 needs "arith/float_atn.hl";;
9 needs "misc/vars.hl";;
10
11
12 module Eval_interval = struct
13
14 open Arith_misc;;
15 open Interval_arith;;
16 open Arith_float;;
17 open Float_atn;;
18 open More_float;;
19 open Misc_vars;;
20
21 (* Creates an interval approximation of the given decimal term *)
22 let mk_float_interval_decimal =
23   let DECIMAL' = SPEC_ALL DECIMAL in
24     fun pp decimal_tm ->
25       let n_tm, d_tm = dest_binary "DECIMAL" decimal_tm in
26       let n, d = dest_numeral n_tm, dest_numeral d_tm in
27       let n_int, d_int = mk_float_interval_num n, mk_float_interval_num d in
28       let int = float_interval_div pp n_int d_int in
29       let eq_th = INST[n_tm, x_var_num; d_tm, y_var_num] DECIMAL' in
30         norm_interval int eq_th;;
31
32
33 (* Unary interval operations *)
34 let unary_interval_operations = 
35   let table = Hashtbl.create 10 in
36   let add = Hashtbl.add in
37     add table `--` (fun pp -> float_interval_neg);
38     add table `inv` float_interval_inv;
39     add table `sqrt` float_interval_sqrt;
40     add table `atn` float_interval_atn;
41     add table `acs` float_interval_acs;
42     table;;
43
44
45 (* Binary interval operations *)
46 let binary_interval_operations = 
47   let table = Hashtbl.create 10 in
48   let add = Hashtbl.add in
49     add table `+` float_interval_add;
50     add table `-` float_interval_sub;
51     add table `*` float_interval_mul;
52     add table `/` float_interval_div;
53     table;;
54
55
56 (* Interval approximations of constants *)
57 let interval_constants =
58   let table = Hashtbl.create 10 in
59   let add = Hashtbl.add in
60     add table `pi` (fun pp -> pi_approx_array.(pp));
61     table;;
62
63
64
65 (* Type of an interval function *)
66 type interval_fun =
67   | Int_ref of int
68   | Int_var of term
69   | Int_const of thm
70   | Int_decimal_const of term
71   | Int_named_const of term
72   | Int_pow of int * interval_fun
73   | Int_unary of term * interval_fun
74   | Int_binary of term * interval_fun * interval_fun;;
75
76
77 (* Evaluates the given interval function at the point
78    defined by the given list of variables *)
79 let eval_interval_fun pp ifun vars refs =
80   let u_find = Hashtbl.find unary_interval_operations and
81       b_find = Hashtbl.find binary_interval_operations and
82       c_find = Hashtbl.find interval_constants in
83   let rec rec_eval f =
84     match f with
85       | Int_ref i -> List.nth refs i
86       | Int_var tm -> assoc tm vars
87       | Int_const th -> th
88       | Int_decimal_const tm -> mk_float_interval_decimal pp tm
89       | Int_named_const tm -> c_find tm pp
90       | Int_pow (n,f1) -> float_interval_pow_simple pp n (rec_eval f1)
91       | Int_unary (tm,f1) -> u_find tm pp (rec_eval f1)
92       | Int_binary (tm,f1,f2) -> b_find tm pp (rec_eval f1) (rec_eval f2) in
93     rec_eval ifun;;
94
95
96 (* Evaluates all sub-expressions involving constants in the given interval function *)
97 let eval_constants pp ifun =
98   let u_find = Hashtbl.find unary_interval_operations and
99       b_find = Hashtbl.find binary_interval_operations and
100       c_find = Hashtbl.find interval_constants in
101   let rec rec_eval f =
102     match f with
103       | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
104       | Int_named_const tm -> Int_const (c_find tm pp)
105       | Int_pow (n,f1) -> 
106           (let f1_val = rec_eval f1 in
107              match f1_val with
108                | Int_const th -> Int_const (float_interval_pow_simple pp n th)
109                | _ -> Int_pow (n,f1_val))
110       | Int_unary (tm,f1) ->
111           (let f1_val = rec_eval f1 in
112              match f1_val with
113                | Int_const th -> Int_const (u_find tm pp th)
114                | _ -> Int_unary (tm, f1_val))
115       | Int_binary (tm,f1,f2) ->
116           (let f1_val, f2_val = rec_eval f1, rec_eval f2 in
117              match f1_val with
118                | Int_const th1 ->
119                    (match f2_val with
120                       | Int_const th2 -> Int_const (b_find tm pp th1 th2)
121                       | _ -> Int_binary (tm, f1_val, f2_val))
122                | _ -> Int_binary (tm, f1_val, f2_val))
123           | _ -> f in
124     rec_eval ifun;;
125
126
127
128 (**************************************)
129
130 (* Builds an interval function from the given term *)
131 let rec build_interval_fun expr_tm =
132   if is_const expr_tm then
133     (* Constant *)
134     Int_named_const expr_tm
135   else if is_var expr_tm then
136     (* Variable *)
137     Int_var expr_tm
138   else
139     let ltm, r_tm = dest_comb expr_tm in
140       (* Unary operations *)
141       if is_const ltm then
142         (* & *)
143         if ltm = amp_op_real then
144           let n = dest_numeral r_tm in
145             Int_const (mk_float_interval_num n)
146         else 
147           let r_fun = build_interval_fun r_tm in
148             Int_unary (ltm, r_fun)
149       else
150         (* Binary operations *)
151         let op, l_tm = dest_comb ltm in
152         let name = (fst o dest_const) op in
153           if name = "DECIMAL" then
154             (* DECIMAL *)
155             Int_decimal_const expr_tm
156           else if name = "real_pow" then
157             (* pow *)
158             let n = dest_small_numeral r_tm in
159               Int_pow (n, build_interval_fun l_tm)
160           else if name = "$" then
161             (* $ *)
162             Int_var expr_tm
163           else
164             let lhs = build_interval_fun l_tm and
165                 rhs = build_interval_fun r_tm in
166               Int_binary (op, lhs, rhs);;
167
168
169 (********************************)
170
171 (* Replaces the given subexpression with the given reference index
172    in all interval functions in the list.
173    Returns the number of replaces and a new list of interval functions *)
174 let replace_subexpr expr expr_index f_list =
175   let rec replace f =
176     if f = expr then
177       1, Int_ref expr_index
178     else
179       match f with
180         | Int_pow (k, f1) ->
181             let c, f1' = replace f1 in
182               c, Int_pow (k, f1')
183         | Int_unary (tm, f1) ->
184             let c, f1' = replace f1 in
185               c, Int_unary (tm, f1')
186         | Int_binary (tm, f1, f2) ->
187             let c1, f1' = replace f1 in
188             let c2, f2' = replace f2 in
189               c1 + c2, Int_binary (tm, f1', f2')
190         | _ -> 0, f in
191   let cs, fs = unzip (map replace f_list) in
192     itlist (+) cs 0, fs;;
193
194
195                 
196 let is_leaf f =
197   match f with
198     | Int_pow _ -> false
199     | Int_unary _ -> false
200     | Int_binary _ -> false
201     | _ -> true;;
202
203 let find_and_replace_all f_list acc =
204   let rec find_and_replace f i f_list =
205     if is_leaf f then
206       f, (0, f_list)
207     else
208       let expr, (c, fs) =
209         match f with
210           | Int_pow (k, f1) -> find_and_replace f1 i f_list
211           | Int_unary (tm, f1) -> find_and_replace f1 i f_list
212           | Int_binary (tm, f1, f2) ->
213               let expr, (c1, fs) = find_and_replace f1 i f_list in
214                 if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
215           | _ -> f, (0, f_list) in
216         if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
217     
218   let rec iterate fs acc =
219     let i = length acc in
220     let expr, (c, fs') = find_and_replace (hd fs) i fs in
221       if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in
222
223   let rec iterate_all f_list ref_acc f_acc =
224     match f_list with
225       | [] -> f_acc, ref_acc
226       | f :: fs ->
227           let fs', acc' = iterate f_list ref_acc in
228             iterate_all (tl fs') acc' (f_acc @ [hd fs']) in
229
230     iterate_all f_list acc [];;
231
232
233 let eval_interval_fun_list pp (f_list, refs) vars =
234   let rec eval_refs refs acc =
235     match refs with
236       | [] -> acc
237       | r :: rs ->
238           let v = eval_interval_fun pp r vars acc in
239             eval_refs rs (acc @ [v]) in
240   let rs = eval_refs refs [] in
241     map (fun f -> eval_interval_fun pp f vars rs) f_list;;
242
243
244 (***************************************)
245
246 (* Approximate the bounds of the given interval with floating point numbers *)
247 let interval_to_float_interval =
248   let th = (UNDISCH_ALL o prove)(`interval_arith x (lo, hi) ==> 
249                                    interval_arith lo (a, y) ==>
250                                    interval_arith hi (z, b)
251                                    ==> interval_arith x (a, b)`,
252                                  REWRITE_TAC[interval_arith] THEN REAL_ARITH_TAC) in
253     fun pp int_th ->
254       let x_tm, bounds = dest_interval_arith (concl int_th) in
255       let lo, hi = dest_pair bounds in
256       let f_lo = build_interval_fun lo and
257           f_hi = build_interval_fun hi in
258       let th_lo = eval_interval_fun pp f_lo [] [] and
259           th_hi = eval_interval_fun pp f_hi [] [] in
260       let a_tm, y_tm = (dest_pair o rand o concl) th_lo and
261           z_tm, b_tm = (dest_pair o rand o concl) th_hi in
262       let th1 = INST[x_tm, x_var_real; lo, lo_var_real; hi, hi_var_real;
263                      a_tm, a_var_real; y_tm, y_var_real;
264                      z_tm, z_var_real; b_tm, b_var_real] th in
265         (MY_PROVE_HYP int_th o MY_PROVE_HYP th_lo o MY_PROVE_HYP th_hi) th1;;
266                      
267
268 (* Adds a new constant approximation to the table of constants *)
269 let add_constant_interval int_th =
270   let c_tm, _ = dest_interval_arith (concl int_th) in
271   let _ = is_const c_tm or failwith "add_constant_interval: not a constant" in
272   let th = interval_to_float_interval 20 int_th in
273   let approx_array = Array.init 20 (fun i -> float_interval_round i th) in
274     Hashtbl.add interval_constants c_tm (fun pp -> approx_array.(pp));;
275
276
277
278 end;;