Update from HH
[Flyspeck/.git] / formal_ineqs / verifier / interval_m / taylor.ml
1 (* =========================================================== *)
2 (* OCaml taylor intervals                                      *)
3 (* Author: Thomas C. Hales                                     *)
4 (* Date: 2011-08-21                                            *)
5 (* Modified: Alexey Solovyev, 2012-10-27                       *)
6 (* =========================================================== *)
7
8 (* port of taylor functions, taylor interval *)
9
10 (*
11 The first part of the file implements basic operations on type taylor_interval.
12
13 Then a type tfunction is defined that represents a twice continuously
14 differentiable function of six variables.  It can be evaluated, which
15 is the taylor_interval data associated with it.
16
17 Sometimes a tfunction f is used to represent an inequality f < 0.
18 (See recurse.hl.
19 *)
20
21 needs "verifier/interval_m/line_interval.ml";;
22 needs "verifier/interval_m/univariate.ml";;
23
24 module Taylor = struct
25
26 open Interval_types;;
27 open Interval;;
28 open Univariate;;
29 open Line_interval;;
30
31
32 (* general utilities *)
33
34 let m8_sum =
35   let ( + ) = iadd in
36     fun dd1 dd2 ->
37       let r8_sum (x,y) = table (fun i -> mth x i + mth y i)  in
38         map r8_sum (zip dd1 dd2);;
39
40 let center_form(x,z) =
41   let ( + ) , ( - ), ( / ) = up(); upadd,upsub,updiv in
42   let y = table (fun i -> if (mth x i=mth z i) then mth x i else (mth x i + mth z i)/ 2.0)  in
43   let w = table (fun i -> max (mth z i - mth y i) (mth y i - mth x i))  in
44   let _ = (minl w >= 0.0) or failwith "centerform" in
45      (y,w);;
46
47 (* start with taylor interval operations *)
48
49 let make_taylor_interval (l1,w1,dd1) = {l = l1; w = w1; dd=dd1;};;
50
51 let ti_add (ti1,ti2) =
52   let _ = (ti1.w = ti2.w) or failwith ("width mismatch in ti") in
53     make_taylor_interval( ladd ti1.l ti2.l,ti1.w, m8_sum ti1.dd ti2.dd);;
54
55 let ti_scale (ti,t) =
56    make_taylor_interval( smul ti.l t,ti.w,  table2 (fun i j ->  imul (mth2 ti.dd i j) t));;
57
58
59
60 let taylor_error ti =
61   let ( + ), ( * ) , ( / )= up(); upadd, upmul, updiv in
62   let dot_abs_row r = List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w r in
63   let dots = map dot_abs_row (ti.dd) in
64     (List.fold_left2 (fun a b c -> a + b * c) 0.0 ti.w dots) / 2.0;;
65 (*  (end_itlist ( + ) p) / 2.0 ;; *)
66
67 let upper_bound ti = 
68   let e = taylor_error ti in
69   let ( + ), ( * ) = up(); upadd, upmul in
70   let t = ti.l.f.hi + e in
71     t + List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w ti.l.df;;
72
73 let lower_bound ti = 
74   let e = taylor_error ti in
75   let ( + ), ( * ),(- ) = down(); downadd,downmul,downsub in
76   let t = ti.l.f.lo - e in
77     t + List.fold_left2 (fun a b c -> a + ( ~-. b) * iabs c) 0.0 ti.w ti.l.df;;
78
79 let upper_partial ti i = 
80   let ( + ), ( * ) =   up(); upadd,upmul in
81     let err = List.fold_left2 (fun a b c -> a + b*(max c.hi (~-. (c.lo)))) 
82       0.0 ti.w (mth ti.dd i) in
83       err + Interval.sup ( mth ti.l.df i);;
84
85 let lower_partial ti i = 
86   let ( + ), ( * ), ( - ) = down();downadd,downmul,downsub in
87     let err = List.fold_left2 (fun a b c -> a + b * min c.lo (~-. (c.hi))) 
88       0.0 ti.w (mth ti.dd i) in
89       Interval.inf ( mth ti.l.df i) + err;;
90
91
92 let ti_mul (ti1,ti2) =
93   let _ = (ti1.w = ti2.w) or failwith ("ti_mul: width mismatch in ti") in
94   let line = lmul ti1.l ti2.l in
95   let f1_int =
96     let lo, hi = lower_bound ti1, upper_bound ti1 in mk_interval (lo, hi) in
97   let f2_int =
98     let lo, hi = lower_bound ti2, upper_bound ti2 in mk_interval (lo, hi) in
99   let d1_ints = table (fun i -> mk_interval (lower_partial ti1 i, upper_partial ti1 i)) in
100   let d2_ints = table (fun i -> mk_interval (lower_partial ti2 i, upper_partial ti2 i)) in
101   let dd = table2 (fun i j ->
102                      let ( + ), ( * ) = iadd, imul in
103                        mth2 ti1.dd i j * f2_int + mth d1_ints i * mth d2_ints j +
104                          mth d1_ints j * mth d2_ints i + f1_int * mth2 ti2.dd i j) in
105     make_taylor_interval(line, ti1.w, dd);;
106     
107
108
109 (* primitive A *)
110
111 type primitiveA = {
112   f_df : int -> float list -> float list -> interval;
113   hfn : float list -> line;
114   second : float list -> float list -> interval list list;
115 };;
116
117 let make_primitiveA (f,h1,s1) = {f_df = f; hfn = h1; second = s1; };;
118
119 let unitA = 
120   let zero2 = table2 (fun i j -> zero) in
121     make_primitiveA (
122       (fun i x z -> if i = 0 then one else zero),
123       (fun y -> line_unit),
124       (fun x z -> zero2)
125 );;
126
127 let evalf4A pA w x y z =
128   make_taylor_interval(
129     pA.hfn y,
130     w,
131     pA.second x z
132   );;
133
134 let line_estimateA pA y = pA.hfn y;;
135
136 (* primitive U *)
137
138 type primitiveU = {
139   slot: int;
140   uv: univariate;
141 };;
142
143 let mk_primitiveU s1 uv1 = 
144   let _ = (s1 < 8) or failwith (Printf.sprintf "slot %d" s1) in
145     { slot = s1; uv = uv1; };;
146
147 let line_estimateU p y = 
148   let y0 = mth y p.slot in
149   let t = mk_interval(y0,y0) in
150   let d = table (fun i -> if (i=p.slot) then eval p.uv t 1 else zero)  in
151     mk_line (    eval p.uv t 0,    d  );;
152
153 let evalf4U =
154   let row0 = table (fun i -> zero)  in
155     fun p w x y z ->
156       let t = mk_interval(mth x p.slot,mth z p.slot) in
157       let row_slot = table  (fun i -> if (i=p.slot) then eval p.uv t 2 else zero)  in
158       let dd = table (fun i -> if (i=p.slot) then row_slot else row0)  in
159       make_taylor_interval(
160         line_estimateU p y,
161         w,
162         dd
163       );;
164
165 type tfunction = 
166   | Prim_a of primitiveA
167   | Uni of primitiveU
168   | Plus of tfunction * tfunction
169   | Product of tfunction * tfunction
170   | Scale of tfunction * interval
171   | Uni_compose of univariate * tfunction
172   | Composite of tfunction *  (* F(g1,g2,g3,g4,g5,g6,g7,g8) *)
173       tfunction *tfunction *tfunction *
174       tfunction *tfunction *tfunction *
175       tfunction *tfunction;;
176
177 let unit = Prim_a unitA;;
178
179 let x1 = Uni (mk_primitiveU 0 ux1);;
180 let x2 = Uni (mk_primitiveU 1 ux1);;
181 let x3 = Uni (mk_primitiveU 2 ux1);;
182 let x4 = Uni (mk_primitiveU 3 ux1);;
183 let x5 = Uni (mk_primitiveU 4 ux1);;
184 let x6 = Uni (mk_primitiveU 5 ux1);;
185
186
187 let x1x2 = 
188   let tab2 = table2 (fun i j -> if (i+j=1) then one else zero) in
189     Prim_a (make_primitiveA(
190               (fun i x z -> 
191                  let x1 = mk_interval (mth x 0, mth z 0) in
192                  let x2 = mk_interval (mth x 1, mth z 1) in
193                    if i = 0 then imul x1 x2
194                    else if i = 1 then x2
195                    else if i = 2 then x1
196                    else zero),
197               (fun y -> 
198                  let u1 = mth y 0 in 
199                  let u2 = mth y 1 in
200                  let x1 = mk_interval(u1,u1) in
201                  let x2 = mk_interval(u2,u2) in
202                    mk_line(
203                      imul x1 x2,
204                      table (fun i -> if i=0 then x2 else if i=1 then x1 else zero)
205                    )),
206               (fun x z -> tab2)));;
207
208 let tf_product tf1 tf2 = Composite(x1x2,tf1,tf2,unit,unit,unit,unit,unit,unit);;
209   
210
211 (* This is one of the most difficult functions in the interval code.
212    It uses the chain rule to compute the second partial derivatives with
213    respect to x(i) x(j), of a function composition
214
215    F(x1,...,x6) = f(g1(x1,...x6),g2(x1,...x6),...,g6(x1,...x6)).
216
217    (F i j) = sum {k m} (f k m) (gk i) (gm j)     + sum {r} (f r) (gr i j).
218
219    Fast performance of this function is very important, especially
220    when many of the functions g* are constant.
221    There is a bit of imperative programming here, in computing the sums.
222
223    Note that ( + ) and ( * ) have different types in various subsections.
224 *)
225
226 let eval_composite =
227   let rest = () in
228   let  sparse_table h f = filter h (List.flatten (table2 f)) in
229     fun hdr p1 p2 p3 p4 p5 p6 p7 p8 w ->
230       let p = [p1;p2;p3;p4;p5;p6;p7;p8] in
231         (* wide and narrow ranges of p *)
232       let (aw,bw) = map (lower_bound) p, map (upper_bound) p  in 
233       let (a,b) = map (fun p -> p.l.f.lo) p, map (fun p -> p.l.f.hi) p in 
234         (* wide and narrow widths from a to b *)
235       let (u,wu,wf) = 
236         let ( + ),( - ),( / ) = up();upadd,upsub,updiv in
237         let u = table (fun i -> (mth a i + mth b i) / 2.0)  in
238         let wu = table (fun i -> max (mth bw i - mth u i) (mth u i - mth aw i))  in
239         let wf = table (fun i -> max (mth b i - mth u i) (mth u i - mth a i))  in
240           (u,wu,wf) in
241       let (fu:taylor_interval) = hdr wu aw u bw in
242       let fpy = 
243         let t = make_taylor_interval(fu.l,wf,fu.dd) in
244           mk_line (
245             mk_interval(lower_bound t, upper_bound t),
246             table (fun i -> mk_interval(lower_partial t i,upper_partial t i))  ) in
247         (* use chain rule imperatively to compute narrow first derivative *)
248       let df_tmp = Array.create 8 zero in
249       let ( + ) = iadd in
250       let ( * ) = imul in
251       let _ = for j=0 to 7 do 
252         let dfj = mth fpy.df j in
253           if is_zero dfj then rest 
254           else for i=0 to 7 do
255             let r = mth (mth p j).l.df i in
256               if (is_zero r) then rest else df_tmp.(i) <- df_tmp.(i) + dfj * r;
257           done;
258       done in
259       let lin = mk_line (       fpy.f, Array.to_list df_tmp ) in
260         (* second derivative init *)
261       let fW_partial = table (fun i -> mk_interval(lower_partial fu i,upper_partial fu i)) in
262       let pW_partial = sparse_table (fun (_,_,z) ->not (is_zero z))  
263         (fun k i -> (k,i,(mk_interval(lower_partial (mth p k) i,upper_partial (mth p k) i)))) in
264         (* chain rule 4-nested loop!, but flattened with sparse table *)
265       let dcw = Array.make_matrix 8 8 zero in 
266       let _ = for i=0 to 7 do for j=0 to 7 do for k=0 to 7 do
267         if (is_zero (mth2 (mth p k).dd i j)) then rest 
268         else dcw.(i).(j) <- dcw.(i).(j) + mth fW_partial k * mth2 ((mth p k).dd) i j ;
269       done; done; done in
270       let len = List.length pW_partial in
271       let _ = for ki = 0 to len-1 do 
272         let (k,i,rki) = List.nth pW_partial ki in
273           for mj=0 to len-1 do
274             let (m,j,rmj) = List.nth pW_partial mj in
275 (*            Report.report (Printf.sprintf "k i m j rki rmj fuddkm = %d %d %d %d %f %f %f" k i m j rki.lo rmj.lo (mth2 fu.dd k m).lo); *)
276               dcw.(i).(j) <- dcw.(i).(j) +  mth2 fu.dd k m * rki * rmj; (* innermost loop *)
277           done; done in
278       let dcw_list =  map Array.to_list (Array.to_list dcw) in
279         make_taylor_interval(lin,w,dcw_list);;
280
281 let rec evalf4 tf w x y z = match tf with
282   | Prim_a p -> evalf4A p w x y z
283   | Uni p -> evalf4U p w x y z
284   | Plus (tf1,tf2) -> ti_add(evalf4 tf1 w x y z, evalf4 tf2 w x y z)
285   | Product (tf1,tf2) -> ti_mul(evalf4 tf1 w x y z, evalf4 tf2 w x y z)
286   | Composite(h,g1,g2,g3,g4,g5,g6,g7,g8) ->
287       let [p1;p2;p3;p4;p5;p6;p7;p8] = map (fun t-> evalf4 t w x y z) [g1;g2;g3;g4;g5;g6;g7;g8] in
288         eval_composite (evalf4 h) p1 p2 p3 p4 p5 p6 p7 p8 w
289   | Scale (tf,t) -> ti_scale ((evalf4 tf w x y z),t)
290   | Uni_compose (uf,tf) -> 
291       let ti = evalf4 tf w x y z in
292       let fy = ti.l.f in
293       let u_fy = uf.u fy in
294       let du_fy = uf.du fy in
295       let line = 
296         let ( * ) = imul in
297           mk_line (u_fy, table (fun i -> du_fy * mth ti.l.df i)) in
298       let fx = mk_interval (lower_bound ti, upper_bound ti) in
299       let dfx = table (fun i -> mk_interval (lower_partial ti i, upper_partial ti i)) in
300       let du_fx = uf.du fx in
301       let ddu_fx = uf.ddu fx in
302       let dd = table2 (fun i j ->
303                          let ( + ), ( * ) = iadd, imul in
304                            (ddu_fx * mth dfx j) * mth dfx i + du_fx * mth2 ti.dd j i) in
305         make_taylor_interval(line, w, dd);;
306         
307       
308
309 (*      evalf4 (Composite(Uni (mk_primitiveU 0 uf),tf,unit,unit,unit,unit,unit,unit,unit)) w x y z;; *)
310
311 let evalf tf x z = 
312   let (y,w) = center_form (x,z) in
313     evalf4 tf w x y z;;
314
315
316 (* Evaluates a function (i = 0) and its first derivatives (i = 1, 2, ...) at the given interval *)
317 let rec evalf0 tf i x z = match tf with
318   | Prim_a p -> p.f_df i x z
319   | Uni p -> 
320       let int = mk_interval (mth x p.slot, mth z p.slot) in
321         if i = 0 then eval p.uv int 0
322         else if i = p.slot + 1 then eval p.uv int 1
323         else zero
324   | Plus (tf1, tf2) -> iadd (evalf0 tf1 i x z) (evalf0 tf2 i x z)
325   | Product (tf1, tf2) -> 
326       let itf1, itf2 = evalf0 tf1 0 x z, evalf0 tf2 0 x z in
327         if i = 0 then imul itf1 itf2
328         else
329           let i_df1, i_df2 = evalf0 tf1 i x z, evalf0 tf2 i x z in
330             iadd (imul i_df1 itf2) (imul itf1 i_df2)
331   | Scale (tf, t) -> imul (evalf0 tf i x z) t
332   | Uni_compose (uf, tf) ->
333       let itf = evalf0 tf 0 x z in
334         if i = 0 then eval uf itf 0
335         else
336           let i_df = evalf0 tf i x z in
337             imul (eval uf itf 1) i_df
338   | Composite (h,g1,g2,g3,g4,g5,g6,g7,g8) ->
339       let gs = [g1;g2;g3;g4;g5;g6;g7;g8] in
340       let ps = map (fun t -> let int = evalf0 t 0 x z in int.lo, int.hi) gs in
341       let x', z' = unzip ps in
342         if i = 0 then evalf0 h 0 x' z'
343         else
344           let dhs = table (fun j -> evalf0 h (j + 1) x' z') in
345           let dgs = map (fun t -> evalf0 t i x z) gs in
346           let ( + ), ( * ) = iadd, imul in
347             itlist2 (fun a b c -> a * b + c) dhs dgs zero;;
348
349         
350 (*      
351 let line_estimate_composite =
352   let ( + ) = iadd in
353   let ( * ) = imul in
354     fun h p1 p2 p3 p4 p5 p6 p7 p8 ->
355       let p =  [p1;p2;p3;p4;p5;p6;p7;p8] in
356       let (a,b) = map (fun p -> p.f.lo) p, map (fun p -> p.f.hi) p in 
357       let fN = evalf h a b in
358       let fN_partial = table (fun i -> mk_interval(lower_partial fN i,upper_partial fN i)) in
359       let pN_partial =table2(fun i j-> (mth (mth p i).df j)) in
360       let cN_partial2 = table2 (fun i j -> mth fN_partial j * mth2 pN_partial j i) in
361       let cN_partial = map (end_itlist ( + )) cN_partial2 in
362         mk_line ( fN.l.f, cN_partial );;
363
364 let rec line_estimate tf y = match tf with
365   | Prim_a p -> line_estimateA p y
366   | Uni p -> line_estimateU p y
367   | Plus (p,q) -> ladd (line_estimate p y) (line_estimate q y)
368   | Scale (p,t) -> smul (line_estimate p y) t
369   | Uni_compose (uf,tf) -> 
370       line_estimate (Composite(Uni { slot=0; uv=uf; },tf,unit,unit,unit,unit,unit,unit,unit)) y
371   | Composite(h,g1,g2,g3,g4,g5,g6,g7,g8) ->
372       let [p1;p2;p3;p4;p5;p6;p7;p8] = map (fun t-> line_estimate t y) [g1;g2;g3;g4;g5;g6;g7;g8] in
373         line_estimate_composite h p1 p2 p3 p4 p5 p6 p7 p8;;
374 *)
375
376 end;;