Update from HH
[Flyspeck/.git] / formal_lp / old / arith / informal / informal_eval_interval.hl
1 needs "../formal_lp/arith/informal/informal_arith.hl";;
2
3 module Informal_eval_interval = struct
4
5 open Informal_interval;;
6 open Informal_float;;
7 open Informal_atn;;
8
9 (* Creates an interval approximation of the given decimal term *)
10 let mk_float_interval_decimal pp decimal_tm =
11   let n_tm, d_tm = dest_binary "DECIMAL" decimal_tm in
12   let n, d = dest_numeral n_tm, dest_numeral d_tm in
13   let n_int, d_int = mk_num_interval n, mk_num_interval d in
14     div_interval pp n_int d_int;;
15
16
17 (* Unary interval operations *)
18 let unary_interval_operations = 
19   let table = Hashtbl.create 10 in
20   let add = Hashtbl.add in
21     add table "real_neg" (fun pp -> neg_interval);
22     add table "real_inv" inv_interval;
23     add table "sqrt" sqrt_interval;
24     add table "atn" atn_interval;
25     add table "acs" acs_interval;
26     table;;
27
28
29 (* Binary interval operations *)
30 let binary_interval_operations = 
31   let table = Hashtbl.create 10 in
32   let add = Hashtbl.add in
33     add table "real_add" add_interval;
34     add table "real_sub" sub_interval;
35     add table "real_mul" mul_interval;
36     add table "real_div" div_interval;
37     table;;
38
39
40 (* Interval approximations of constants *)
41 let interval_constants =
42   let table = Hashtbl.create 10 in
43   let add = Hashtbl.add in
44     add table "pi" (fun pp -> pi_approx_array.(pp));
45     table;;
46
47
48
49 (* Type of an interval function *)
50 type interval_fun =
51   | Int_ref of int
52   | Int_var of int
53   | Int_const of interval
54   | Int_decimal_const of term
55   | Int_named_const of string
56   | Int_pow of int * interval_fun
57   | Int_unary of string * interval_fun
58   | Int_binary of string * interval_fun * interval_fun;;
59
60
61 (* Equality of interval functions *)
62 let rec eq_ifun ifun1 ifun2 =
63   match (ifun1, ifun2) with
64     | (Int_ref r1, Int_ref r2) -> r1 = r2
65     | (Int_var v1, Int_var v2) -> v1 = v2
66     | (Int_decimal_const tm1, Int_decimal_const tm2) -> tm1 = tm2
67     | (Int_named_const name1, Int_named_const name2) -> name1 = name2
68     | (Int_pow (n1, f1), Int_pow (n2, f2)) -> n1 = n2 && eq_ifun f1 f2
69     | (Int_unary (op1, f1), Int_unary (op2, f2)) -> op1 = op2 && eq_ifun f1 f2
70     | (Int_binary (op1, f1, g1), Int_binary (op2, f2, g2)) -> op1 = op2 && eq_ifun f1 f2 && eq_ifun g1 g2
71     | (Int_const int1, Int_const int2) ->
72         let lo1, hi1 = dest_interval int1 and
73             lo2, hi2 = dest_interval int2 in
74           eq_float lo1 lo2 && eq_float hi1 hi2
75     | _ -> false;;
76
77
78 (* Evaluates the given interval function at the point
79    defined by the given list of variables *)
80 let eval_interval_fun =
81   let u_find = Hashtbl.find unary_interval_operations and
82       b_find = Hashtbl.find binary_interval_operations and
83       c_find = Hashtbl.find interval_constants in
84     fun pp ifun vars refs ->
85       let rec rec_eval f =
86         match f with
87           | Int_ref i -> List.nth refs i
88           | Int_var i -> List.nth vars (i - 1)
89           | Int_const int -> int
90           | Int_decimal_const tm -> mk_float_interval_decimal pp tm
91           | Int_named_const name -> (c_find name) pp
92           | Int_pow (n,f1) -> pow_interval pp n (rec_eval f1)
93           | Int_unary (op,f1) -> (u_find op) pp (rec_eval f1)
94           | Int_binary (op,f1,f2) -> (b_find op) pp (rec_eval f1) (rec_eval f2) in
95         rec_eval ifun;;
96
97
98 (* Evaluates all sub-expressions involving constants in the given interval function *)
99 let eval_constants =
100   let u_find = Hashtbl.find unary_interval_operations and
101       b_find = Hashtbl.find binary_interval_operations and
102       c_find = Hashtbl.find interval_constants in
103     fun pp ifun ->
104       let rec rec_eval f =
105         match f with
106           | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
107           | Int_named_const name -> Int_const (c_find name pp)
108           | Int_pow (n, f1) -> 
109               (let f1_val = rec_eval f1 in
110                  match f1_val with
111                    | Int_const int -> Int_const (pow_interval pp n int)
112                    | _ -> Int_pow (n,f1_val))
113           | Int_unary (op, f1) ->
114               (let f1_val = rec_eval f1 in
115                  match f1_val with
116                    | Int_const int -> Int_const (u_find op pp int)
117                    | _ -> Int_unary (op, f1_val))
118           | Int_binary (op, f1, f2) ->
119               (let f1_val, f2_val = rec_eval f1, rec_eval f2 in
120                  match f1_val with
121                    | Int_const int1 ->
122                        (match f2_val with
123                           | Int_const int2 -> Int_const (b_find op pp int1 int2)
124                           | _ -> Int_binary (op, f1_val, f2_val))
125                    | _ -> Int_binary (op, f1_val, f2_val))
126           | _ -> f in
127         rec_eval ifun;;
128
129
130
131 (**************************************)
132
133 (* Builds an interval function from the given term expression *)
134 let build_interval_fun =
135   let amp_op_real = `(&):num -> real` in
136   let rec rec_build expr_tm =
137     if is_const expr_tm then
138       (* Constant *)
139       Int_named_const (fst (dest_const expr_tm))
140     else if is_var expr_tm then
141       (* Variables should be of the form name$i *)
142       failwith ("Variables should be of the form name$i: " ^ string_of_term expr_tm)
143     else
144       let ltm, r_tm = dest_comb expr_tm in
145         (* Unary operations *)
146         if is_const ltm then
147           (* & *)
148           if ltm = amp_op_real then
149             let n = dest_numeral r_tm in
150               Int_const (mk_num_interval n)
151           else 
152             let r_fun = rec_build r_tm in
153               Int_unary ((fst o dest_const) ltm, r_fun)
154         else
155           (* Binary operations *)
156           let op, l_tm = dest_comb ltm in
157           let name = (fst o dest_const) op in
158             if name = "DECIMAL" then
159               (* DECIMAL *)
160               Int_decimal_const expr_tm
161             else if name = "real_pow" then
162               (* pow *)
163               let n = dest_small_numeral r_tm in
164                 Int_pow (n, rec_build l_tm)
165             else if name = "$" then
166               (* $ *)
167               Int_var (dest_small_numeral (rand expr_tm))
168             else
169               let lhs = rec_build l_tm and
170                   rhs = rec_build r_tm in
171                 Int_binary ((fst o dest_const) op, lhs, rhs) in
172     rec_build;;
173
174
175 (* Replaces the given subexpression with the given reference index
176    for all interval functions in the list.
177    Returns the number of replaces and a new list of interval functions *)
178 let replace_subexpr expr expr_index f_list =
179   let rec replace f =
180     if eq_ifun f expr then
181       1, Int_ref expr_index
182     else
183       match f with
184         | Int_pow (k, f1) ->
185             let c, f1' = replace f1 in
186               c, Int_pow (k, f1')
187         | Int_unary (op, f1) ->
188             let c, f1' = replace f1 in
189               c, Int_unary (op, f1')
190         | Int_binary (op, f1, f2) ->
191             let c1, f1' = replace f1 in
192             let c2, f2' = replace f2 in
193               c1 + c2, Int_binary (op, f1', f2')
194         | _ -> 0, f in
195   let cs, fs = unzip (map replace f_list) in
196     itlist (+) cs 0, fs;;
197
198
199                 
200 let is_leaf f =
201   match f with
202     | Int_pow _ -> false
203     | Int_unary _ -> false
204     | Int_binary _ -> false
205     | _ -> true;;
206
207 let find_and_replace_all f_list acc =
208   let rec find_and_replace f i f_list =
209     if is_leaf f then
210       f, (0, f_list)
211     else
212       let expr, (c, fs) =
213         match f with
214           | Int_pow (k, f1) -> find_and_replace f1 i f_list
215           | Int_unary (op, f1) -> find_and_replace f1 i f_list
216           | Int_binary (op, f1, f2) ->
217               let expr, (c1, fs) = find_and_replace f1 i f_list in
218                 if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
219           | _ -> f, (0, f_list) in
220         if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
221     
222   let rec iterate fs acc =
223     let i = length acc in
224     let expr, (c, fs') = find_and_replace (hd fs) i fs in
225       if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in
226
227   let rec iterate_all f_list ref_acc f_acc =
228     match f_list with
229       | [] -> f_acc, ref_acc
230       | f :: fs ->
231           let fs', acc' = iterate f_list ref_acc in
232             iterate_all (tl fs') acc' (f_acc @ [hd fs']) in
233
234     iterate_all f_list acc [];;
235
236
237 let eval_interval_fun_list pp (f_list, refs) vars =
238   let rec eval_refs refs acc =
239     match refs with
240       | [] -> acc
241       | r :: rs ->
242           let v = eval_interval_fun pp r vars acc in
243             eval_refs rs (acc @ [v]) in
244   let rs = eval_refs refs [] in
245     map (fun f -> eval_interval_fun pp f vars rs) f_list;;
246
247
248 end;;
249
250
251 (*
252 (* Tests *)
253 needs "../formal_lp/formal_interval/eval_interval.hl";;
254
255 let pp = 7;;
256
257 let var_tm = `&1 / #7.1`;;
258 let var = eval_interval_fun pp (build_interval_fun var_tm) [] [];;
259 let var0 = Informal_eval_interval.eval_interval_fun pp 
260   (Informal_eval_interval.build_interval_fun var_tm) [] [];;
261
262 let test_vars = [`(x:real^1)$1`, var];;
263 let test_vars0 = [var0];;
264 let test_expr1 = `(x:real^1)$1 * x$1 + (&3 * x$1) * x$1 * x$1 + &3 * x$1 + &3 * x$1`;;
265 let test_expr2 = `((x:real^1)$1 * x$1) * (x$1 * &2) + x$1 * &2`;;
266 let subexpr1 = `(x:real^1)$1 * x$1` and subexpr2 = `&3 * (x:real^1)$1`;;
267
268 let test_f1 = build_interval_fun test_expr1 and
269     test_f2 = build_interval_fun test_expr2;;
270 let sub1 = build_interval_fun subexpr1 and sub2 = build_interval_fun subexpr2;;
271
272 let f1, f2, s1, s2 =
273   let build = Informal_eval_interval.build_interval_fun in
274     build test_expr1, build test_expr2, build subexpr1, build subexpr2;;
275
276 let v = find_and_replace_all [test_f1; test_f2] [];;
277 let v0 = Informal_eval_interval.find_and_replace_all [f1; f2] [];;
278
279 let test_dest int =
280   let lo, hi = Informal_interval.dest_interval int in
281     Informal_float.dest_float lo, Informal_float.dest_float hi;;
282
283 eval_interval_fun_list pp v test_vars;;
284 map test_dest (Informal_eval_interval.eval_interval_fun_list pp v0 test_vars0);;
285
286 map (fun f -> eval_interval_fun pp f test_vars []) [test_f1; test_f2];;
287 let r = map (fun f -> Informal_eval_interval.eval_interval_fun pp f test_vars0 []) [f1; f2];;
288 map test_dest r;;
289
290
291 (* 0.712 *)
292 test 100 (map (fun f -> eval_interval_fun pp f test_vars [])) [test_f1; test_f2];;
293 (* 0.432 *)
294 test 100 (eval_interval_fun_list pp v) test_vars;;
295
296 test 100 (map (fun f -> Informal_eval_interval.eval_interval_fun pp f test_vars0 [])) [f1; f2];;
297 test 100 (Informal_eval_interval.eval_interval_fun_list pp v0) test_vars0;;
298 *)