Update from HH
[Flyspeck/.git] / formal_lp / old / formal_interval / eval_interval.hl
1 needs "../formal_lp/formal_interval/more_float.hl";;
2
3 let x_var_num = `x:num` and
4     y_var_num = `y:num`;;
5
6 let amp_op_real = `(&):num->real`;;
7
8 (* Creates an interval approximation of the given decimal term *)
9 let mk_float_interval_decimal =
10   let DECIMAL' = SPEC_ALL DECIMAL in
11     fun pp decimal_tm ->
12       let n_tm, d_tm = dest_binary "DECIMAL" decimal_tm in
13       let n, d = dest_numeral n_tm, dest_numeral d_tm in
14       let n_int, d_int = mk_float_interval_num n, mk_float_interval_num d in
15       let int = float_interval_div pp n_int d_int in
16       let eq_th = INST[n_tm, x_var_num; d_tm, y_var_num] DECIMAL' in
17         norm_interval int eq_th;;
18
19
20 (* Unary interval operations *)
21 let unary_interval_operations = 
22   let table = Hashtbl.create 10 in
23   let add = Hashtbl.add in
24     add table `--` (fun pp -> float_interval_neg);
25     add table `inv` float_interval_inv;
26     add table `sqrt` float_interval_sqrt;
27     add table `atn` float_interval_atn;
28     add table `acs` float_interval_acs;
29     table;;
30
31
32 (* Binary interval operations *)
33 let binary_interval_operations = 
34   let table = Hashtbl.create 10 in
35   let add = Hashtbl.add in
36     add table `+` float_interval_add;
37     add table `-` float_interval_sub;
38     add table `*` float_interval_mul;
39     add table `/` float_interval_div;
40     table;;
41
42
43 (* Interval approximations of constants *)
44 let interval_constants =
45   let table = Hashtbl.create 10 in
46   let add = Hashtbl.add in
47     add table `pi` (fun pp -> pi_approx_array.(pp));
48     table;;
49
50
51
52 (* Type of an interval function *)
53 type interval_fun =
54   | Int_ref of int
55   | Int_var of term
56   | Int_const of thm
57   | Int_decimal_const of term
58   | Int_named_const of term
59   | Int_pow of int * interval_fun
60   | Int_unary of term * interval_fun
61   | Int_binary of term * interval_fun * interval_fun;;
62
63
64 (* Evaluates the given interval function at the point
65    defined by the given list of variables *)
66 let eval_interval_fun pp ifun vars refs =
67   let u_find = Hashtbl.find unary_interval_operations and
68       b_find = Hashtbl.find binary_interval_operations and
69       c_find = Hashtbl.find interval_constants in
70   let rec rec_eval f =
71     match f with
72           | Int_ref i -> List.nth refs i
73       | Int_var tm -> assoc tm vars
74       | Int_const th -> th
75       | Int_decimal_const tm -> mk_float_interval_decimal pp tm
76       | Int_named_const tm -> c_find tm pp
77       | Int_pow (n,f1) -> float_interval_pow_simple pp n (rec_eval f1)
78       | Int_unary (tm,f1) -> u_find tm pp (rec_eval f1)
79       | Int_binary (tm,f1,f2) -> b_find tm pp (rec_eval f1) (rec_eval f2) in
80     rec_eval ifun;;
81
82
83 (* Evaluates all sub-expressions involving constants in the given interval function *)
84 let eval_constants pp ifun =
85   let u_find = Hashtbl.find unary_interval_operations and
86       b_find = Hashtbl.find binary_interval_operations and
87       c_find = Hashtbl.find interval_constants in
88   let rec rec_eval f =
89     match f with
90       | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
91       | Int_named_const tm -> Int_const (c_find tm pp)
92       | Int_pow (n,f1) -> 
93           (let f1_val = rec_eval f1 in
94              match f1_val with
95                | Int_const th -> Int_const (float_interval_pow_simple pp n th)
96                | _ -> Int_pow (n,f1_val))
97       | Int_unary (tm,f1) ->
98           (let f1_val = rec_eval f1 in
99              match f1_val with
100                | Int_const th -> Int_const (u_find tm pp th)
101                | _ -> Int_unary (tm, f1_val))
102       | Int_binary (tm,f1,f2) ->
103           (let f1_val, f2_val = rec_eval f1, rec_eval f2 in
104              match f1_val with
105                | Int_const th1 ->
106                    (match f2_val with
107                       | Int_const th2 -> Int_const (b_find tm pp th1 th2)
108                       | _ -> Int_binary (tm, f1_val, f2_val))
109                | _ -> Int_binary (tm, f1_val, f2_val))
110           | _ -> f in
111     rec_eval ifun;;
112
113
114
115 (**************************************)
116
117 (* Builds an interval function from the given term expression *)
118 let rec build_interval_fun expr_tm =
119   if is_const expr_tm then
120     (* Constant *)
121     Int_named_const expr_tm
122   else if is_var expr_tm then
123     (* Variable *)
124     Int_var expr_tm
125   else
126     let ltm, r_tm = dest_comb expr_tm in
127       (* Unary operations *)
128       if is_const ltm then
129         (* & *)
130         if ltm = amp_op_real then
131           let n = dest_numeral r_tm in
132             Int_const (mk_float_interval_num n)
133         else 
134           let r_fun = build_interval_fun r_tm in
135             Int_unary (ltm, r_fun)
136       else
137         (* Binary operations *)
138         let op, l_tm = dest_comb ltm in
139         let name = (fst o dest_const) op in
140           if name = "DECIMAL" then
141             (* DECIMAL *)
142             Int_decimal_const expr_tm
143           else if name = "real_pow" then
144             (* pow *)
145             let n = dest_small_numeral r_tm in
146               Int_pow (n, build_interval_fun l_tm)
147           else if name = "$" then
148             (* $ *)
149             Int_var expr_tm
150           else
151             let lhs = build_interval_fun l_tm and
152                 rhs = build_interval_fun r_tm in
153               Int_binary (op, lhs, rhs);;
154
155
156 (*
157 let test_vars = [`x:real`, two_interval];;
158 let f = build_interval_fun `(&1 + &3 * pi) + sqrt (#3.13525238353 * x)`;;
159 let f2 = eval_constants pp f;;
160 eval_interval_fun pp f test_vars;;
161 eval_interval_fun pp f2 test_vars;;
162
163 test 100 (eval_interval_fun pp f) test_vars;;
164 test 100 (eval_interval_fun pp f2) test_vars;;
165 *)
166
167 (* Replaces the given subexpression with the given reference index
168    in all interval functions in the list.
169    Returns the number of replaces and a new list of interval functions *)
170 let replace_subexpr expr expr_index f_list =
171   let rec replace f =
172     if f = expr then
173       1, Int_ref expr_index
174     else
175       match f with
176         | Int_pow (k, f1) ->
177             let c, f1' = replace f1 in
178               c, Int_pow (k, f1')
179         | Int_unary (tm, f1) ->
180             let c, f1' = replace f1 in
181               c, Int_unary (tm, f1')
182         | Int_binary (tm, f1, f2) ->
183             let c1, f1' = replace f1 in
184             let c2, f2' = replace f2 in
185               c1 + c2, Int_binary (tm, f1', f2')
186         | _ -> 0, f in
187   let cs, fs = unzip (map replace f_list) in
188     itlist (+) cs 0, fs;;
189
190
191                 
192 let is_leaf f =
193   match f with
194     | Int_pow _ -> false
195     | Int_unary _ -> false
196     | Int_binary _ -> false
197     | _ -> true;;
198
199 let find_and_replace_all f_list acc =
200   let rec find_and_replace f i f_list =
201     if is_leaf f then
202       f, (0, f_list)
203     else
204       let expr, (c, fs) =
205         match f with
206           | Int_pow (k, f1) -> find_and_replace f1 i f_list
207           | Int_unary (tm, f1) -> find_and_replace f1 i f_list
208           | Int_binary (tm, f1, f2) ->
209               let expr, (c1, fs) = find_and_replace f1 i f_list in
210                 if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
211           | _ -> f, (0, f_list) in
212         if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
213     
214   let rec iterate fs acc =
215     let i = length acc in
216     let expr, (c, fs') = find_and_replace (hd fs) i fs in
217       if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in
218
219   let rec iterate_all f_list ref_acc f_acc =
220     match f_list with
221       | [] -> f_acc, ref_acc
222       | f :: fs ->
223           let fs', acc' = iterate f_list ref_acc in
224             iterate_all (tl fs') acc' (f_acc @ [hd fs']) in
225
226     iterate_all f_list acc [];;
227
228
229 let eval_interval_fun_list pp (f_list, refs) vars =
230   let rec eval_refs refs acc =
231     match refs with
232       | [] -> acc
233       | r :: rs ->
234           let v = eval_interval_fun pp r vars acc in
235             eval_refs rs (acc @ [v]) in
236   let rs = eval_refs refs [] in
237     map (fun f -> eval_interval_fun pp f vars rs) f_list;;
238
239
240 (*
241 let pp = 5;;
242 let test_vars = [`x:real`, two_interval];;
243 let test_expr1 = `x * x + (&3 * x) * x * x + &3 * x + &3 * x`;;
244 let test_expr2 = `(x * x) * (x * &2) + x * &2`;;
245 let subexpr1 = `x * x` and subexpr2 = `&3 * x`;;
246 let test_f1 = build_interval_fun test_expr1 and
247     test_f2 = build_interval_fun test_expr2;;
248 let sub1 = build_interval_fun subexpr1 and sub2 = build_interval_fun subexpr2;;
249
250
251 let v = find_and_replace_all [test_f1; test_f2] [];;
252
253 eval_interval_fun_list pp v test_vars;;
254 map (fun f -> eval_interval_fun pp f test_vars []) [test_f1; test_f2];;
255
256 (* 0.712 *)
257 test 100 (map (fun f -> eval_interval_fun pp f test_vars [])) [test_f1; test_f2];;
258 (* 0.432 *)
259 test 100 (eval_interval_fun_list pp v) test_vars;;
260 *)
261