Update from HH
[Flyspeck/.git] / formal_ineqs / informal / informal_eval_interval.hl
1 (* =========================================================== *)
2 (* Informal interval evaluation of arithmetic expressions      *)
3 (* Author: Alexey Solovyev                                     *)
4 (* Date: 2012-10-27                                            *)
5 (* =========================================================== *)
6
7 needs "informal/informal_arith.hl";;
8
9 module Informal_eval_interval = struct
10
11 open Informal_interval;;
12 open Informal_float;;
13 open Informal_atn;;
14
15 (* Creates an interval approximation of the given decimal term *)
16 let mk_float_interval_decimal pp decimal_tm =
17   let n_tm, d_tm = dest_binary "DECIMAL" decimal_tm in
18   let n, d = dest_numeral n_tm, dest_numeral d_tm in
19   let n_int, d_int = mk_num_interval n, mk_num_interval d in
20     div_interval pp n_int d_int;;
21
22
23 (* Unary interval operations *)
24 let unary_interval_operations = 
25   let table = Hashtbl.create 10 in
26   let add = Hashtbl.add in
27     add table "real_neg" (fun pp -> neg_interval);
28     add table "real_inv" inv_interval;
29     add table "sqrt" sqrt_interval;
30     add table "atn" atn_interval;
31     add table "acs" acs_interval;
32     table;;
33
34
35 (* Binary interval operations *)
36 let binary_interval_operations = 
37   let table = Hashtbl.create 10 in
38   let add = Hashtbl.add in
39     add table "real_add" add_interval;
40     add table "real_sub" sub_interval;
41     add table "real_mul" mul_interval;
42     add table "real_div" div_interval;
43     table;;
44
45
46 (* Interval approximations of constants *)
47 let interval_constants =
48   let table = Hashtbl.create 10 in
49   let add = Hashtbl.add in
50     add table "pi" (fun pp -> pi_approx_array.(pp));
51     table;;
52
53
54
55 (* Type of an interval function *)
56 type interval_fun =
57   | Int_ref of int
58   | Int_var of int
59   | Int_const of interval
60   | Int_decimal_const of term
61   | Int_named_const of string
62   | Int_pow of int * interval_fun
63   | Int_unary of string * interval_fun
64   | Int_binary of string * interval_fun * interval_fun;;
65
66
67 (* Equality of interval functions *)
68 let rec eq_ifun ifun1 ifun2 =
69   match (ifun1, ifun2) with
70     | (Int_ref r1, Int_ref r2) -> r1 = r2
71     | (Int_var v1, Int_var v2) -> v1 = v2
72     | (Int_decimal_const tm1, Int_decimal_const tm2) -> tm1 = tm2
73     | (Int_named_const name1, Int_named_const name2) -> name1 = name2
74     | (Int_pow (n1, f1), Int_pow (n2, f2)) -> n1 = n2 && eq_ifun f1 f2
75     | (Int_unary (op1, f1), Int_unary (op2, f2)) -> op1 = op2 && eq_ifun f1 f2
76     | (Int_binary (op1, f1, g1), Int_binary (op2, f2, g2)) -> op1 = op2 && eq_ifun f1 f2 && eq_ifun g1 g2
77     | (Int_const int1, Int_const int2) ->
78         let lo1, hi1 = dest_interval int1 and
79             lo2, hi2 = dest_interval int2 in
80           eq_float lo1 lo2 && eq_float hi1 hi2
81     | _ -> false;;
82
83
84 (* Evaluates the given interval function at the point
85    defined by the given list of variables *)
86 let eval_interval_fun =
87   let u_find = Hashtbl.find unary_interval_operations and
88       b_find = Hashtbl.find binary_interval_operations and
89       c_find = Hashtbl.find interval_constants in
90     fun pp ifun vars refs ->
91       let rec rec_eval f =
92         match f with
93           | Int_ref i -> List.nth refs i
94           | Int_var i -> List.nth vars (i - 1)
95           | Int_const int -> int
96           | Int_decimal_const tm -> mk_float_interval_decimal pp tm
97           | Int_named_const name -> (c_find name) pp
98           | Int_pow (n,f1) -> pow_interval pp n (rec_eval f1)
99           | Int_unary (op,f1) -> (u_find op) pp (rec_eval f1)
100           | Int_binary (op,f1,f2) -> (b_find op) pp (rec_eval f1) (rec_eval f2) in
101         rec_eval ifun;;
102
103
104 (* Evaluates all sub-expressions involving constants in the given interval function *)
105 let eval_constants =
106   let u_find = Hashtbl.find unary_interval_operations and
107       b_find = Hashtbl.find binary_interval_operations and
108       c_find = Hashtbl.find interval_constants in
109     fun pp ifun ->
110       let rec rec_eval f =
111         match f with
112           | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
113           | Int_named_const name -> Int_const (c_find name pp)
114           | Int_pow (n, f1) -> 
115               (let f1_val = rec_eval f1 in
116                  match f1_val with
117                    | Int_const int -> Int_const (pow_interval pp n int)
118                    | _ -> Int_pow (n,f1_val))
119           | Int_unary (op, f1) ->
120               (let f1_val = rec_eval f1 in
121                  match f1_val with
122                    | Int_const int -> Int_const (u_find op pp int)
123                    | _ -> Int_unary (op, f1_val))
124           | Int_binary (op, f1, f2) ->
125               (let f1_val, f2_val = rec_eval f1, rec_eval f2 in
126                  match f1_val with
127                    | Int_const int1 ->
128                        (match f2_val with
129                           | Int_const int2 -> Int_const (b_find op pp int1 int2)
130                           | _ -> Int_binary (op, f1_val, f2_val))
131                    | _ -> Int_binary (op, f1_val, f2_val))
132           | _ -> f in
133         rec_eval ifun;;
134
135
136
137 (**************************************)
138
139 (* Builds an interval function from the given term expression *)
140 let build_interval_fun =
141   let amp_op_real = `(&):num -> real` in
142   let rec rec_build expr_tm =
143     if is_const expr_tm then
144       (* Constant *)
145       Int_named_const (fst (dest_const expr_tm))
146     else if is_var expr_tm then
147       (* Variables should be of the form name$i *)
148       failwith ("Variables should be of the form name$i: " ^ string_of_term expr_tm)
149     else
150       let ltm, r_tm = dest_comb expr_tm in
151         (* Unary operations *)
152         if is_const ltm then
153           (* & *)
154           if ltm = amp_op_real then
155             let n = dest_numeral r_tm in
156               Int_const (mk_num_interval n)
157           else 
158             let r_fun = rec_build r_tm in
159               Int_unary ((fst o dest_const) ltm, r_fun)
160         else
161           (* Binary operations *)
162           let op, l_tm = dest_comb ltm in
163           let name = (fst o dest_const) op in
164             if name = "DECIMAL" then
165               (* DECIMAL *)
166               Int_decimal_const expr_tm
167             else if name = "real_pow" then
168               (* pow *)
169               let n = dest_small_numeral r_tm in
170                 Int_pow (n, rec_build l_tm)
171             else if name = "$" then
172               (* $ *)
173               Int_var (dest_small_numeral (rand expr_tm))
174             else
175               let lhs = rec_build l_tm and
176                   rhs = rec_build r_tm in
177                 Int_binary ((fst o dest_const) op, lhs, rhs) in
178     rec_build;;
179
180
181 (* Replaces the given subexpression with the given reference index
182    for all interval functions in the list.
183    Returns the number of replaces and a new list of interval functions *)
184 let replace_subexpr expr expr_index f_list =
185   let rec replace f =
186     if eq_ifun f expr then
187       1, Int_ref expr_index
188     else
189       match f with
190         | Int_pow (k, f1) ->
191             let c, f1' = replace f1 in
192               c, Int_pow (k, f1')
193         | Int_unary (op, f1) ->
194             let c, f1' = replace f1 in
195               c, Int_unary (op, f1')
196         | Int_binary (op, f1, f2) ->
197             let c1, f1' = replace f1 in
198             let c2, f2' = replace f2 in
199               c1 + c2, Int_binary (op, f1', f2')
200         | _ -> 0, f in
201   let cs, fs = unzip (map replace f_list) in
202     itlist (+) cs 0, fs;;
203
204
205                 
206 let is_leaf f =
207   match f with
208     | Int_pow _ -> false
209     | Int_unary _ -> false
210     | Int_binary _ -> false
211     | _ -> true;;
212
213 let find_and_replace_all f_list acc =
214   let rec find_and_replace f i f_list =
215     if is_leaf f then
216       f, (0, f_list)
217     else
218       let expr, (c, fs) =
219         match f with
220           | Int_pow (k, f1) -> find_and_replace f1 i f_list
221           | Int_unary (op, f1) -> find_and_replace f1 i f_list
222           | Int_binary (op, f1, f2) ->
223               let expr, (c1, fs) = find_and_replace f1 i f_list in
224                 if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
225           | _ -> f, (0, f_list) in
226         if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
227     
228   let rec iterate fs acc =
229     let i = length acc in
230     let expr, (c, fs') = find_and_replace (hd fs) i fs in
231       if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in
232
233   let rec iterate_all f_list ref_acc f_acc =
234     match f_list with
235       | [] -> f_acc, ref_acc
236       | f :: fs ->
237           let fs', acc' = iterate f_list ref_acc in
238             iterate_all (tl fs') acc' (f_acc @ [hd fs']) in
239
240     iterate_all f_list acc [];;
241
242
243 let eval_interval_fun_list pp (f_list, refs) vars =
244   let rec eval_refs refs acc =
245     match refs with
246       | [] -> acc
247       | r :: rs ->
248           let v = eval_interval_fun pp r vars acc in
249             eval_refs rs (acc @ [v]) in
250   let rs = eval_refs refs [] in
251     map (fun f -> eval_interval_fun pp f vars rs) f_list;;
252
253
254 (* Approximate the bounds of the given interval with floating point numbers *)
255 let interval_to_float_interval pp int_th =
256   let lo_tm, hi_tm = (dest_pair o rand o concl) int_th in
257   let f_lo = build_interval_fun lo_tm and
258       f_hi = build_interval_fun hi_tm in
259   let int_lo = eval_interval_fun pp f_lo [] [] and
260       int_hi = eval_interval_fun pp f_hi [] [] in
261   let a, _ = dest_interval int_lo and
262       _, b = dest_interval int_hi in
263     mk_interval (a, b);;
264                      
265
266 (* Adds a new constant approximation to the table of constants *)
267 let add_constant_interval int_th =
268   let c_tm = (rand o rator o concl) int_th in
269   let _ = is_const c_tm or failwith "add_constant_interval: not a constant" in
270   let interval = interval_to_float_interval 20 int_th in
271   let approx_array = Array.init 20 (fun i -> round_interval (if i = 0 then 1 else i) interval) in
272     Hashtbl.add interval_constants (fst (dest_const c_tm)) (fun pp -> approx_array.(pp));;
273
274         
275 end;;