Update from HH
[Flyspeck/.git] / formal_lp / old / formal_interval / interval_m / taylor.hl
1 (* port of taylor functions, taylor interval *)
2
3 (*
4 The first part of the file implements basic operations on type taylor_interval.
5
6 Then a type tfunction is defined that represents a twice continuously
7 differentiable function of six variables.  It can be evaluated, which
8 is the taylor_interval data associated with it.
9
10 Sometimes a tfunction f is used to represent an inequality f < 0.
11 (See recurse.hl.
12 *)
13
14
15 module Taylor = struct
16
17 open Interval;;
18 open Univariate;;
19 open Line_interval;;
20
21
22 (* general utilities *)
23
24
25 let m8_sum =
26   let ( + ) = iadd in
27     fun dd1 dd2 ->
28       let r8_sum (x,y) = table (fun i -> mth x i + mth y i)  in
29         map r8_sum (zip dd1 dd2);;
30
31 let center_form(x,z) =
32   let ( + ) , ( - ), ( / ) = up(); upadd,upsub,updiv in
33   let y = table (fun i -> if (mth x i=mth z i) then mth x i else (mth x i + mth z i)/ 2.0)  in
34   let w = table (fun i -> max (mth z i - mth y i) (mth y i - mth x i))  in
35   let _ = (minl w >= 0.0) or failwith "centerform" in
36      (y,w);;
37
38 (* start with taylor interval operations *)
39
40 let make_taylor_interval (l1,w1,dd1) = {l = l1; w = w1; dd=dd1;};;
41
42 let ti_add (ti1,ti2) =
43   let _ = (ti1.w = ti2.w) or failwith ("width mismatch in ti") in
44     make_taylor_interval( ladd ti1.l ti2.l,ti1.w, m8_sum ti1.dd ti2.dd);;
45
46 let ti_scale (ti,t) =
47    make_taylor_interval( smul ti.l t,ti.w,  table2 (fun i j ->  imul (mth2 ti.dd i j) t));;
48
49
50
51 let taylor_error ti =
52   let ( + ), ( * ) , ( / )= up(); upadd, upmul, updiv in
53   let dot_abs_row r = List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w r in
54   let dots = map dot_abs_row (ti.dd) in
55     (List.fold_left2 (fun a b c -> a + b * c) 0.0 ti.w dots) / 2.0;;
56 (*  (end_itlist ( + ) p) / 2.0 ;; *)
57
58 let upper_bound ti = 
59   let e = taylor_error ti in
60   let ( + ), ( * ) = up(); upadd, upmul in
61   let t = ti.l.f.hi + e in
62     t + List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w ti.l.df;;
63
64 let lower_bound ti = 
65   let e = taylor_error ti in
66   let ( + ), ( * ),(- ) = down(); downadd,downmul,downsub in
67   let t = ti.l.f.lo - e in
68     t + List.fold_left2 (fun a b c -> a + ( ~-. b) * iabs c) 0.0 ti.w ti.l.df;;
69
70 let upper_partial ti i = 
71   let ( + ), ( * ) =   up(); upadd,upmul in
72     let err = List.fold_left2 (fun a b c -> a + b*(max c.hi (~-. (c.lo)))) 
73       0.0 ti.w (mth ti.dd i) in
74       err + Interval.sup ( mth ti.l.df i);;
75
76 let lower_partial ti i = 
77   let ( + ), ( * ), ( - ) = down();downadd,downmul,downsub in
78     let err = List.fold_left2 (fun a b c -> a + b * min c.lo (~-. (c.hi))) 
79       0.0 ti.w (mth ti.dd i) in
80       Interval.inf ( mth ti.l.df i) + err;;
81
82
83 let ti_mul (ti1,ti2) =
84   let _ = (ti1.w = ti2.w) or failwith ("ti_mul: width mismatch in ti") in
85   let line = lmul ti1.l ti2.l in
86   let f1_int =
87     let lo, hi = lower_bound ti1, upper_bound ti1 in mk_interval (lo, hi) in
88   let f2_int =
89     let lo, hi = lower_bound ti2, upper_bound ti2 in mk_interval (lo, hi) in
90   let d1_ints = table (fun i -> mk_interval (lower_partial ti1 i, upper_partial ti1 i)) in
91   let d2_ints = table (fun i -> mk_interval (lower_partial ti2 i, upper_partial ti2 i)) in
92   let dd = table2 (fun i j ->
93                      let ( + ), ( * ) = iadd, imul in
94                        mth2 ti1.dd i j * f2_int + mth d1_ints i * mth d2_ints j +
95                          mth d1_ints j * mth d2_ints i + f1_int * mth2 ti2.dd i j) in
96     make_taylor_interval(line, ti1.w, dd);;
97     
98
99
100 (* primitive A *)
101
102 type primitiveA = {
103   f_df : int -> float list -> float list -> interval;
104   hfn : float list -> line;
105   second : float list -> float list -> interval list list;
106 };;
107
108 let make_primitiveA (f,h1,s1) = {f_df = f; hfn = h1; second = s1; };;
109
110 let unitA = 
111   let zero2 = table2 (fun i j -> zero) in
112     make_primitiveA (
113       (fun i x z -> if i = 0 then one else zero),
114       (fun y -> line_unit),
115       (fun x z -> zero2)
116 );;
117
118 let evalf4A pA w x y z =
119   make_taylor_interval(
120     pA.hfn y,
121     w,
122     pA.second x z
123   );;
124
125 let line_estimateA pA y = pA.hfn y;;
126
127 (* primitive U *)
128
129 type primitiveU = {
130   slot: int;
131   uv: univariate;
132 };;
133
134 let mk_primitiveU s1 uv1 = 
135   let _ = (s1 < 8) or failwith (Printf.sprintf "slot %d" s1) in
136     { slot = s1; uv = uv1; };;
137
138 let line_estimateU p y = 
139   let y0 = mth y p.slot in
140   let t = mk_interval(y0,y0) in
141   let d = table (fun i -> if (i=p.slot) then eval p.uv t 1 else zero)  in
142     mk_line (    eval p.uv t 0,    d  );;
143
144 let evalf4U =
145   let row0 = table (fun i -> zero)  in
146     fun p w x y z ->
147       let t = mk_interval(mth x p.slot,mth z p.slot) in
148       let row_slot = table  (fun i -> if (i=p.slot) then eval p.uv t 2 else zero)  in
149       let dd = table (fun i -> if (i=p.slot) then row_slot else row0)  in
150       make_taylor_interval(
151         line_estimateU p y,
152         w,
153         dd
154       );;
155
156 type tfunction = 
157   | Prim_a of primitiveA
158   | Uni of primitiveU
159   | Plus of tfunction * tfunction
160   | Product of tfunction * tfunction
161   | Scale of tfunction * interval
162   | Uni_compose of univariate * tfunction
163   | Composite of tfunction *  (* F(g1,g2,g3,g4,g5,g6,g7,g8) *)
164       tfunction *tfunction *tfunction *
165       tfunction *tfunction *tfunction *
166       tfunction *tfunction;;
167
168 let unit = Prim_a unitA;;
169
170 let x1 = Uni (mk_primitiveU 0 ux1);;
171 let x2 = Uni (mk_primitiveU 1 ux1);;
172 let x3 = Uni (mk_primitiveU 2 ux1);;
173 let x4 = Uni (mk_primitiveU 3 ux1);;
174 let x5 = Uni (mk_primitiveU 4 ux1);;
175 let x6 = Uni (mk_primitiveU 5 ux1);;
176
177
178 let x1x2 = 
179   let tab2 = table2 (fun i j -> if (i+j=1) then one else zero) in
180     Prim_a (make_primitiveA(
181               (fun i x z -> 
182                  let x1 = mk_interval (mth x 0, mth z 0) in
183                  let x2 = mk_interval (mth x 1, mth z 1) in
184                    if i = 0 then imul x1 x2
185                    else if i = 1 then x2
186                    else if i = 2 then x1
187                    else zero),
188               (fun y -> 
189                  let u1 = mth y 0 in 
190                  let u2 = mth y 1 in
191                  let x1 = mk_interval(u1,u1) in
192                  let x2 = mk_interval(u2,u2) in
193                    mk_line(
194                      imul x1 x2,
195                      table (fun i -> if i=0 then x2 else if i=1 then x1 else zero)
196                    )),
197               (fun x z -> tab2)));;
198
199 let tf_product tf1 tf2 = Composite(x1x2,tf1,tf2,unit,unit,unit,unit,unit,unit);;
200   
201
202 (* This is one of the most difficult functions in the interval code.
203    It uses the chain rule to compute the second partial derivatives with
204    respect to x(i) x(j), of a function composition
205
206    F(x1,...,x6) = f(g1(x1,...x6),g2(x1,...x6),...,g6(x1,...x6)).
207
208    (F i j) = sum {k m} (f k m) (gk i) (gm j)     + sum {r} (f r) (gr i j).
209
210    Fast performance of this function is very important, especially
211    when many of the functions g* are constant.
212    There is a bit of imperative programming here, in computing the sums.
213
214    Note that ( + ) and ( * ) have different types in various subsections.
215 *)
216
217 let eval_composite =
218   let rest = () in
219   let  sparse_table h f = filter h (List.flatten (table2 f)) in
220     fun hdr p1 p2 p3 p4 p5 p6 p7 p8 w ->
221       let p = [p1;p2;p3;p4;p5;p6;p7;p8] in
222         (* wide and narrow ranges of p *)
223       let (aw,bw) = map (lower_bound) p, map (upper_bound) p  in 
224       let (a,b) = map (fun p -> p.l.f.lo) p, map (fun p -> p.l.f.hi) p in 
225         (* wide and narrow widths from a to b *)
226       let (u,wu,wf) = 
227         let ( + ),( - ),( / ) = up();upadd,upsub,updiv in
228         let u = table (fun i -> (mth a i + mth b i) / 2.0)  in
229         let wu = table (fun i -> max (mth bw i - mth u i) (mth u i - mth aw i))  in
230         let wf = table (fun i -> max (mth b i - mth u i) (mth u i - mth a i))  in
231           (u,wu,wf) in
232       let (fu:taylor_interval) = hdr wu aw u bw in
233       let fpy = 
234         let t = make_taylor_interval(fu.l,wf,fu.dd) in
235           mk_line (
236             mk_interval(lower_bound t, upper_bound t),
237             table (fun i -> mk_interval(lower_partial t i,upper_partial t i))  ) in
238         (* use chain rule imperatively to compute narrow first derivative *)
239       let df_tmp = Array.create 8 zero in
240       let ( + ) = iadd in
241       let ( * ) = imul in
242       let _ = for j=0 to 7 do 
243         let dfj = mth fpy.df j in
244           if is_zero dfj then rest 
245           else for i=0 to 7 do
246             let r = mth (mth p j).l.df i in
247               if (is_zero r) then rest else df_tmp.(i) <- df_tmp.(i) + dfj * r;
248           done;
249       done in
250       let lin = mk_line (       fpy.f, Array.to_list df_tmp ) in
251         (* second derivative init *)
252       let fW_partial = table (fun i -> mk_interval(lower_partial fu i,upper_partial fu i)) in
253       let pW_partial = sparse_table (fun (_,_,z) ->not (is_zero z))  
254         (fun k i -> (k,i,(mk_interval(lower_partial (mth p k) i,upper_partial (mth p k) i)))) in
255         (* chain rule 4-nested loop!, but flattened with sparse table *)
256       let dcw = Array.make_matrix 8 8 zero in 
257       let _ = for i=0 to 7 do for j=0 to 7 do for k=0 to 7 do
258         if (is_zero (mth2 (mth p k).dd i j)) then rest 
259         else dcw.(i).(j) <- dcw.(i).(j) + mth fW_partial k * mth2 ((mth p k).dd) i j ;
260       done; done; done in
261       let len = List.length pW_partial in
262       let _ = for ki = 0 to len-1 do 
263         let (k,i,rki) = List.nth pW_partial ki in
264           for mj=0 to len-1 do
265             let (m,j,rmj) = List.nth pW_partial mj in
266 (*            Report.report (Printf.sprintf "k i m j rki rmj fuddkm = %d %d %d %d %f %f %f" k i m j rki.lo rmj.lo (mth2 fu.dd k m).lo); *)
267               dcw.(i).(j) <- dcw.(i).(j) +  mth2 fu.dd k m * rki * rmj; (* innermost loop *)
268           done; done in
269       let dcw_list =  map Array.to_list (Array.to_list dcw) in
270         make_taylor_interval(lin,w,dcw_list);;
271
272 let rec evalf4 tf w x y z = match tf with
273   | Prim_a p -> evalf4A p w x y z
274   | Uni p -> evalf4U p w x y z
275   | Plus (tf1,tf2) -> ti_add(evalf4 tf1 w x y z, evalf4 tf2 w x y z)
276   | Product (tf1,tf2) -> ti_mul(evalf4 tf1 w x y z, evalf4 tf2 w x y z)
277   | Composite(h,g1,g2,g3,g4,g5,g6,g7,g8) ->
278       let [p1;p2;p3;p4;p5;p6;p7;p8] = map (fun t-> evalf4 t w x y z) [g1;g2;g3;g4;g5;g6;g7;g8] in
279         eval_composite (evalf4 h) p1 p2 p3 p4 p5 p6 p7 p8 w
280   | Scale (tf,t) -> ti_scale ((evalf4 tf w x y z),t)
281   | Uni_compose (uf,tf) -> 
282       let ti = evalf4 tf w x y z in
283       let fy = ti.l.f in
284       let u_fy = uf.u fy in
285       let du_fy = uf.du fy in
286       let line = 
287         let ( * ) = imul in
288           mk_line (u_fy, table (fun i -> du_fy * mth ti.l.df i)) in
289       let fx = mk_interval (lower_bound ti, upper_bound ti) in
290       let dfx = table (fun i -> mk_interval (lower_partial ti i, upper_partial ti i)) in
291       let du_fx = uf.du fx in
292       let ddu_fx = uf.ddu fx in
293       let dd = table2 (fun i j ->
294                          let ( + ), ( * ) = iadd, imul in
295                            (ddu_fx * mth dfx j) * mth dfx i + du_fx * mth2 ti.dd j i) in
296         make_taylor_interval(line, w, dd);;
297         
298       
299
300 (*      evalf4 (Composite(Uni (mk_primitiveU 0 uf),tf,unit,unit,unit,unit,unit,unit,unit)) w x y z;; *)
301
302 let evalf tf x z = 
303   let (y,w) = center_form (x,z) in
304     evalf4 tf w x y z;;
305
306
307 (* Evaluates a function (i = 0) and its first derivatives (i = 1, 2, ...) at the given interval *)
308 let rec evalf0 tf i x z = match tf with
309   | Prim_a p -> p.f_df i x z
310   | Uni p -> 
311       let int = mk_interval (mth x p.slot, mth z p.slot) in
312         if i = 0 then eval p.uv int 0
313         else if i = p.slot + 1 then eval p.uv int 1
314         else zero
315   | Plus (tf1, tf2) -> iadd (evalf0 tf1 i x z) (evalf0 tf2 i x z)
316   | Product (tf1, tf2) -> 
317       let itf1, itf2 = evalf0 tf1 0 x z, evalf0 tf2 0 x z in
318         if i = 0 then imul itf1 itf2
319         else
320           let i_df1, i_df2 = evalf0 tf1 i x z, evalf0 tf2 i x z in
321             iadd (imul i_df1 itf2) (imul itf1 i_df2)
322   | Scale (tf, t) -> imul (evalf0 tf i x z) t
323   | Uni_compose (uf, tf) ->
324       let itf = evalf0 tf 0 x z in
325         if i = 0 then eval uf itf 0
326         else
327           let i_df = evalf0 tf i x z in
328             imul (eval uf itf 1) i_df
329   | Composite (h,g1,g2,g3,g4,g5,g6,g7,g8) ->
330       let gs = [g1;g2;g3;g4;g5;g6;g7;g8] in
331       let ps = map (fun t -> let int = evalf0 t 0 x z in int.lo, int.hi) gs in
332       let x', z' = unzip ps in
333         if i = 0 then evalf0 h 0 x' z'
334         else
335           let dhs = table (fun j -> evalf0 h (j + 1) x' z') in
336           let dgs = map (fun t -> evalf0 t i x z) gs in
337           let ( + ), ( * ) = iadd, imul in
338             itlist2 (fun a b c -> a * b + c) dhs dgs zero;;
339
340         
341 (*      
342 let line_estimate_composite =
343   let ( + ) = iadd in
344   let ( * ) = imul in
345     fun h p1 p2 p3 p4 p5 p6 p7 p8 ->
346       let p =  [p1;p2;p3;p4;p5;p6;p7;p8] in
347       let (a,b) = map (fun p -> p.f.lo) p, map (fun p -> p.f.hi) p in 
348       let fN = evalf h a b in
349       let fN_partial = table (fun i -> mk_interval(lower_partial fN i,upper_partial fN i)) in
350       let pN_partial =table2(fun i j-> (mth (mth p i).df j)) in
351       let cN_partial2 = table2 (fun i j -> mth fN_partial j * mth2 pN_partial j i) in
352       let cN_partial = map (end_itlist ( + )) cN_partial2 in
353         mk_line ( fN.l.f, cN_partial );;
354
355 let rec line_estimate tf y = match tf with
356   | Prim_a p -> line_estimateA p y
357   | Uni p -> line_estimateU p y
358   | Plus (p,q) -> ladd (line_estimate p y) (line_estimate q y)
359   | Scale (p,t) -> smul (line_estimate p y) t
360   | Uni_compose (uf,tf) -> 
361       line_estimate (Composite(Uni { slot=0; uv=uf; },tf,unit,unit,unit,unit,unit,unit,unit)) y
362   | Composite(h,g1,g2,g3,g4,g5,g6,g7,g8) ->
363       let [p1;p2;p3;p4;p5;p6;p7;p8] = map (fun t-> line_estimate t y) [g1;g2;g3;g4;g5;g6;g7;g8] in
364         line_estimate_composite h p1 p2 p3 p4 p5 p6 p7 p8;;
365 *)
366
367 end;;