1 (* =========================================================== *)
2 (* Formal interval evaluation of arithmetic expressions *)
3 (* Author: Alexey Solovyev *)
5 (* =========================================================== *)
7 needs "arith/more_float.hl";;
8 needs "arith/float_atn.hl";;
12 module Eval_interval = struct
21 (* Creates an interval approximation of the given decimal term *)
22 let mk_float_interval_decimal =
23 let DECIMAL' = SPEC_ALL DECIMAL in
25 let n_tm, d_tm = dest_binary "DECIMAL" decimal_tm in
26 let n, d = dest_numeral n_tm, dest_numeral d_tm in
27 let n_int, d_int = mk_float_interval_num n, mk_float_interval_num d in
28 let int = float_interval_div pp n_int d_int in
29 let eq_th = INST[n_tm, x_var_num; d_tm, y_var_num] DECIMAL' in
30 norm_interval int eq_th;;
33 (* Unary interval operations *)
34 let unary_interval_operations =
35 let table = Hashtbl.create 10 in
36 let add = Hashtbl.add in
37 add table `--` (fun pp -> float_interval_neg);
38 add table `inv` float_interval_inv;
39 add table `sqrt` float_interval_sqrt;
40 add table `atn` float_interval_atn;
41 add table `acs` float_interval_acs;
45 (* Binary interval operations *)
46 let binary_interval_operations =
47 let table = Hashtbl.create 10 in
48 let add = Hashtbl.add in
49 add table `+` float_interval_add;
50 add table `-` float_interval_sub;
51 add table `*` float_interval_mul;
52 add table `/` float_interval_div;
56 (* Interval approximations of constants *)
57 let interval_constants =
58 let table = Hashtbl.create 10 in
59 let add = Hashtbl.add in
60 add table `pi` (fun pp -> pi_approx_array.(pp));
65 (* Type of an interval function *)
70 | Int_decimal_const of term
71 | Int_named_const of term
72 | Int_pow of int * interval_fun
73 | Int_unary of term * interval_fun
74 | Int_binary of term * interval_fun * interval_fun;;
77 (* Evaluates the given interval function at the point
78 defined by the given list of variables *)
79 let eval_interval_fun pp ifun vars refs =
80 let u_find = Hashtbl.find unary_interval_operations and
81 b_find = Hashtbl.find binary_interval_operations and
82 c_find = Hashtbl.find interval_constants in
85 | Int_ref i -> List.nth refs i
86 | Int_var tm -> assoc tm vars
88 | Int_decimal_const tm -> mk_float_interval_decimal pp tm
89 | Int_named_const tm -> c_find tm pp
90 | Int_pow (n,f1) -> float_interval_pow_simple pp n (rec_eval f1)
91 | Int_unary (tm,f1) -> u_find tm pp (rec_eval f1)
92 | Int_binary (tm,f1,f2) -> b_find tm pp (rec_eval f1) (rec_eval f2) in
96 (* Evaluates all sub-expressions involving constants in the given interval function *)
97 let eval_constants pp ifun =
98 let u_find = Hashtbl.find unary_interval_operations and
99 b_find = Hashtbl.find binary_interval_operations and
100 c_find = Hashtbl.find interval_constants in
103 | Int_decimal_const tm -> Int_const (mk_float_interval_decimal pp tm)
104 | Int_named_const tm -> Int_const (c_find tm pp)
106 (let f1_val = rec_eval f1 in
108 | Int_const th -> Int_const (float_interval_pow_simple pp n th)
109 | _ -> Int_pow (n,f1_val))
110 | Int_unary (tm,f1) ->
111 (let f1_val = rec_eval f1 in
113 | Int_const th -> Int_const (u_find tm pp th)
114 | _ -> Int_unary (tm, f1_val))
115 | Int_binary (tm,f1,f2) ->
116 (let f1_val, f2_val = rec_eval f1, rec_eval f2 in
120 | Int_const th2 -> Int_const (b_find tm pp th1 th2)
121 | _ -> Int_binary (tm, f1_val, f2_val))
122 | _ -> Int_binary (tm, f1_val, f2_val))
128 (**************************************)
130 (* Builds an interval function from the given term *)
131 let rec build_interval_fun expr_tm =
132 if is_const expr_tm then
134 Int_named_const expr_tm
135 else if is_var expr_tm then
139 let ltm, r_tm = dest_comb expr_tm in
140 (* Unary operations *)
143 if ltm = amp_op_real then
144 let n = dest_numeral r_tm in
145 Int_const (mk_float_interval_num n)
147 let r_fun = build_interval_fun r_tm in
148 Int_unary (ltm, r_fun)
150 (* Binary operations *)
151 let op, l_tm = dest_comb ltm in
152 let name = (fst o dest_const) op in
153 if name = "DECIMAL" then
155 Int_decimal_const expr_tm
156 else if name = "real_pow" then
158 let n = dest_small_numeral r_tm in
159 Int_pow (n, build_interval_fun l_tm)
160 else if name = "$" then
164 let lhs = build_interval_fun l_tm and
165 rhs = build_interval_fun r_tm in
166 Int_binary (op, lhs, rhs);;
169 (********************************)
171 (* Replaces the given subexpression with the given reference index
172 in all interval functions in the list.
173 Returns the number of replaces and a new list of interval functions *)
174 let replace_subexpr expr expr_index f_list =
177 1, Int_ref expr_index
181 let c, f1' = replace f1 in
183 | Int_unary (tm, f1) ->
184 let c, f1' = replace f1 in
185 c, Int_unary (tm, f1')
186 | Int_binary (tm, f1, f2) ->
187 let c1, f1' = replace f1 in
188 let c2, f2' = replace f2 in
189 c1 + c2, Int_binary (tm, f1', f2')
191 let cs, fs = unzip (map replace f_list) in
192 itlist (+) cs 0, fs;;
199 | Int_unary _ -> false
200 | Int_binary _ -> false
203 let find_and_replace_all f_list acc =
204 let rec find_and_replace f i f_list =
210 | Int_pow (k, f1) -> find_and_replace f1 i f_list
211 | Int_unary (tm, f1) -> find_and_replace f1 i f_list
212 | Int_binary (tm, f1, f2) ->
213 let expr, (c1, fs) = find_and_replace f1 i f_list in
214 if c1 > 1 then expr, (c1, fs) else find_and_replace f2 i f_list
215 | _ -> f, (0, f_list) in
216 if c > 1 then expr, (c, fs) else f, replace_subexpr f i f_list in
218 let rec iterate fs acc =
219 let i = length acc in
220 let expr, (c, fs') = find_and_replace (hd fs) i fs in
221 if c > 1 then iterate fs' (acc @ [expr]) else fs, acc in
223 let rec iterate_all f_list ref_acc f_acc =
225 | [] -> f_acc, ref_acc
227 let fs', acc' = iterate f_list ref_acc in
228 iterate_all (tl fs') acc' (f_acc @ [hd fs']) in
230 iterate_all f_list acc [];;
233 let eval_interval_fun_list pp (f_list, refs) vars =
234 let rec eval_refs refs acc =
238 let v = eval_interval_fun pp r vars acc in
239 eval_refs rs (acc @ [v]) in
240 let rs = eval_refs refs [] in
241 map (fun f -> eval_interval_fun pp f vars rs) f_list;;
244 (***************************************)
246 (* Approximate the bounds of the given interval with floating point numbers *)
247 let interval_to_float_interval =
248 let th = (UNDISCH_ALL o prove)(`interval_arith x (lo, hi) ==>
249 interval_arith lo (a, y) ==>
250 interval_arith hi (z, b)
251 ==> interval_arith x (a, b)`,
252 REWRITE_TAC[interval_arith] THEN REAL_ARITH_TAC) in
254 let x_tm, bounds = dest_interval_arith (concl int_th) in
255 let lo, hi = dest_pair bounds in
256 let f_lo = build_interval_fun lo and
257 f_hi = build_interval_fun hi in
258 let th_lo = eval_interval_fun pp f_lo [] [] and
259 th_hi = eval_interval_fun pp f_hi [] [] in
260 let a_tm, y_tm = (dest_pair o rand o concl) th_lo and
261 z_tm, b_tm = (dest_pair o rand o concl) th_hi in
262 let th1 = INST[x_tm, x_var_real; lo, lo_var_real; hi, hi_var_real;
263 a_tm, a_var_real; y_tm, y_var_real;
264 z_tm, z_var_real; b_tm, b_var_real] th in
265 (MY_PROVE_HYP int_th o MY_PROVE_HYP th_lo o MY_PROVE_HYP th_hi) th1;;
(* Adds a new constant approximation to the table of constants.

   [add_constant_interval int_th] takes a theorem whose conclusion is
   destructed by [dest_interval_arith] into a term [c] and its bounds
   (i.e. |- interval_arith c (lo, hi)); [c] must be a constant term.
   The bounds are converted to a floating-point interval with
   [interval_to_float_interval] at precision [max_pp] (20), and then
   [float_interval_round] is applied for every precision 0 .. max_pp - 1
   (presumably rounding the approximation to that precision -- see
   float_interval_round).  The resulting lookup function is registered
   in [interval_constants] under the key [c], replacing any earlier
   registration for the same constant.

   Fails with "add_constant_interval: not a constant" when [c] is not a
   constant.  The registered lookup indexes an array of size [max_pp],
   so requesting a precision outside 0 .. max_pp - 1 raises
   Invalid_argument. *)
let add_constant_interval int_th =
  (* Number of precomputed precision levels (valid indices 0 .. max_pp - 1). *)
  let max_pp = 20 in
  let c_tm, _ = dest_interval_arith (concl int_th) in
  (* Check before doing any (potentially expensive) interval evaluation. *)
  if not (is_const c_tm) then
    failwith "add_constant_interval: not a constant";
  let th = interval_to_float_interval max_pp int_th in
  let approx_array = Array.init max_pp (fun i -> float_interval_round i th) in
  (* [replace] rather than [add]: re-registering a constant should not leave
     stale, shadowed bindings behind in the table. *)
  Hashtbl.replace interval_constants c_tm (fun pp -> approx_array.(pp));;