Update from HH
[Flyspeck/.git] / port_interval / taylor.hl
1 (* port of taylor functions, taylor interval *)
2
3 (*
4 The first part of the file implements basic operations on type taylor_interval.
5
6 Then a type tfunction is defined that represents a twice continuously
7 differentiable function of six variables.  It can be evaluated, which
8 is the taylor_interval data associated with it.
9
10 Sometimes a tfunction f is used to represent an inequality f < 0.
11 (See recurse.hl.
12 *)
13
14
15 module Taylor = struct
16
17 open Interval;;
18 open Univariate;;
19 open Line_interval;;
20
21
22 (* general utilities *)
23
24
25 let m6_sum =
26   let ( + ) = iadd in
27     fun dd1 dd2 ->
28       let r6_sum (x,y) = table (fun i -> mth x i + mth y i)  in
29         map r6_sum (zip dd1 dd2);;
30
31 let center_form(x,z) =
32   let ( + ) , ( - ), ( / ) = up(); upadd,upsub,updiv in
33   let y = table (fun i -> if (mth x i=mth z i) then mth x i else (mth x i + mth z i)/ 2.0)  in
34   let w = table (fun i -> max (mth z i - mth y i) (mth y i - mth x i))  in
35   let _ = (minl w >= 0.0) or failwith "centerform" in
36      (y,w);;
37
38 (* start with taylor interval operations *)
39
40 let make_taylor_interval (l1,w1,dd1) = {l = l1; w = w1; dd=dd1;};;
41
42 let ti_add (ti1,ti2) =
43   let _ = (ti1.w = ti2.w) or failwith ("width mismatch in ti") in
44    make_taylor_interval( ladd ti1.l ti2.l,ti1.w, m6_sum ti1.dd ti2.dd);;
45
46 let ti_scale (ti,t) =
47    make_taylor_interval( smul ti.l t,ti.w,  table2 (fun i j ->  imul (mth2 ti.dd i j) t));;
48
49 let taylor_error ti =
50   let ( + ), ( * ) , ( / )= up(); upadd, upmul, updiv in
51   let dot_abs_row r = List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w r in
52   let dots = map dot_abs_row (ti.dd) in
53     (List.fold_left2 (fun a b c -> a + b * c) 0.0 ti.w dots) / 2.0;;
54 (*  (end_itlist ( + ) p) / 2.0 ;; *)
55
56 let upper_bound ti = 
57   let e = taylor_error ti in
58   let ( + ), ( * ) = up(); upadd, upmul in
59   let t = ti.l.f.hi + e in
60     t + List.fold_left2 (fun a b c -> a + b * iabs c) 0.0 ti.w ti.l.df;;
61
62 let lower_bound ti = 
63   let e = taylor_error ti in
64   let ( + ), ( * ),(- ) = down(); downadd,downmul,downsub in
65   let t = ti.l.f.lo - e in
66     t + List.fold_left2 (fun a b c -> a + ( ~-. b) * iabs c) 0.0 ti.w ti.l.df;;
67
68 let upper_partial ti i = 
69   let ( + ), ( * ) =   up(); upadd,upmul in
70     let err = List.fold_left2 (fun a b c -> a + b*(max c.hi (~-. (c.lo)))) 
71       0.0 ti.w (mth ti.dd i) in
72       err + Interval.sup ( mth ti.l.df i);;
73
74 let lower_partial ti i = 
75   let ( + ), ( * ), ( - ) = down();downadd,downmul,downsub in
76     let err = List.fold_left2 (fun a b c -> a + b * min c.lo (~-. (c.hi))) 
77       0.0 ti.w (mth ti.dd i) in
78       Interval.inf ( mth ti.l.df i) + err;;
79
80
81 (* primitive A *)
82
83 type primitiveA = {
84   hfn : float list -> line;
85   second : float list -> float list -> interval list list;
86 };;
87
88 let make_primitiveA (h1,s1) = {hfn = h1; second = s1; };;
89
90 let unitA = 
91   let zero2 = table2 (fun i j -> zero) in
92     make_primitiveA (
93       (fun y -> line_unit),
94       (fun x z -> zero2)
95 );;
96
97 let evalf4A pA w x y z =
98   make_taylor_interval(
99     pA.hfn y,
100     w,
101     pA.second x z
102   );;
103
104 let line_estimateA pA y = pA.hfn y;;
105
106 (* primitive U *)
107
108 type primitiveU = {
109   slot: int;
110   uv: univariate;
111 };;
112
113 let mk_primitiveU s1 uv1 = 
114   let _ = (s1 < 6) or failwith (Printf.sprintf "slot %d" s1) in
115     { slot = s1; uv = uv1; };;
116
117 let line_estimateU p y = 
118   let y0 = mth y p.slot in
119   let t = mk_interval(y0,y0) in
120   let d = table (fun i -> if (i=p.slot) then eval p.uv t 1 else zero)  in
121     mk_line (    eval p.uv t 0,    d  );;
122
123 let evalf4U =
124   let row0 = table (fun i -> zero)  in
125     fun p w x y z ->
126       let t = mk_interval(mth x p.slot,mth z p.slot) in
127       let row_slot = table  (fun i -> if (i=p.slot) then eval p.uv t 2 else zero)  in
128       let dd = table (fun i -> if (i=p.slot) then row_slot else row0)  in
129       make_taylor_interval(
130         line_estimateU p y,
131         w,
132         dd
133       );;
134
135 type tfunction = 
136   |  Prim_a of primitiveA
137   |  Uni of primitiveU
138   |  Plus of tfunction * tfunction
139   |  Scale of tfunction * interval
140   |  Uni_compose of univariate * tfunction
141   |  Composite of tfunction *  (* F(g1,g2,g3,g4,g5,g6) *)
142        tfunction *tfunction *tfunction *
143        tfunction *tfunction *tfunction ;;
144
145 let unit = Prim_a unitA;;
146
147 let x1 = Uni (mk_primitiveU 0 ux1);;
148 let x2 = Uni (mk_primitiveU 1 ux1);;
149 let x3 = Uni (mk_primitiveU 2 ux1);;
150 let x4 = Uni (mk_primitiveU 3 ux1);;
151 let x5 = Uni (mk_primitiveU 4 ux1);;
152 let x6 = Uni (mk_primitiveU 5 ux1);;
153
154 let x1x2 = 
155   let tab2 = table2 (fun i j -> if (i+j=1) then one else zero) in
156     Prim_a (make_primitiveA(
157                      (fun y -> 
158                         let u1 = mth y 0 in 
159                         let u2 = mth y 1 in
160                         let x1 = mk_interval(u1,u1) in
161                         let x2 = mk_interval(u2,u2) in
162                           mk_line(
163                             imul x1 x2,
164                             table (fun i -> if i=0 then x2 else if i=1 then x1 else zero)
165                           )),
166                      (fun x z -> tab2)));;
167
168 let rotate2 f = Composite(f,x2,x3,x1,x5,x6,x4);;
169 let rotate3 f = Composite(f,x3,x1,x2,x6,x4,x5);;
170 let rotate4 f = Composite(f,x4,x2,x6,x1,x5,x3);;
171 let rotate5 f = Composite(f,x5,x3,x4,x2,x6,x1);;
172 let rotate6 f = Composite(f,x6,x1,x5,x3,x4,x2);;
173
174 let tf_product tf1 tf2 = Composite(x1x2,tf1,tf2,unit,unit,unit,unit);;
175   
176
177 (* This is one of the most difficult functions in the interval code.
178    It uses the chain rule to compute the second partial derivatives with
179    respect to x(i) x(j), of a function composition
180
181    F(x1,...,x6) = f(g1(x1,...x6),g2(x1,...x6),...,g6(x1,...x6)).
182
183    (F i j) = sum {k m} (f k m) (gk i) (gm j)     + sum {r} (f r) (gr i j).
184
185    Fast performance of this function is very important, especially
186    when many of the functions g* are constant.
187    There is a bit of imperative programming here, in computing the sums.
188
189    Note that ( + ) and ( * ) have different types in various subsections.
190 *)
191
192 let eval_composite =
193   let rest = () in
194   let  sparse_table h f = filter h (List.flatten (table2 f)) in
195     fun hdr p1 p2 p3 p4 p5 p6 w ->
196       let p = [p1;p2;p3;p4;p5;p6] in
197         (* wide and narrow ranges of p *)
198       let (aw,bw) = map (lower_bound) p, map (upper_bound) p  in 
199       let (a,b) = map (fun p -> p.l.f.lo) p, map (fun p -> p.l.f.hi) p in 
200         (* wide and narrow widths from a to b *)
201       let (u,wu,wf) = 
202         let ( + ),( - ),( / ) = up();upadd,upsub,updiv in
203         let u = table (fun i -> (mth a i + mth b i) / 2.0)  in
204         let wu = table (fun i -> max (mth bw i - mth u i) (mth u i - mth aw i))  in
205         let wf = table (fun i -> max (mth b i - mth u i) (mth u i - mth a i))  in
206           (u,wu,wf) in
207       let (fu:taylor_interval) = hdr wu aw u bw in
208       let fpy = 
209         let t = make_taylor_interval(fu.l,wf,fu.dd) in
210           mk_line (
211             mk_interval(lower_bound t, upper_bound t),
212             table (fun i -> mk_interval(lower_partial t i,upper_partial t i))  ) in
213         (* use chain rule imperatively to compute narrow first derivative *)
214       let df_tmp = Array.create 6 zero in
215       let ( + ) = iadd in
216       let ( * ) = imul in
217       let _ = for j=0 to 5 do 
218         let dfj = mth fpy.df j in
219           if is_zero dfj then rest 
220           else for i=0 to 5 do
221             let r = mth (mth p j).l.df i in
222               if (is_zero r) then rest else df_tmp.(i) <- df_tmp.(i) + dfj * r;
223           done;
224       done in
225       let lin = mk_line (       fpy.f, Array.to_list df_tmp ) in
226         (* second derivative init *)
227       let fW_partial = table (fun i -> mk_interval(lower_partial fu i,upper_partial fu i)) in
228       let pW_partial = sparse_table (fun (_,_,z) ->not (is_zero z))  
229         (fun k i -> (k,i,(mk_interval(lower_partial (mth p k) i,upper_partial (mth p k) i)))) in
230         (* chain rule 4-nested loop!, but flattened with sparse table *)
231       let dcw = Array.make_matrix 6 6 zero in 
232       let _ = for i=0 to 5 do for j=0 to 5 do for k=0 to 5 do
233         if (is_zero (mth2 (mth p k).dd i j)) then rest 
234         else dcw.(i).(j) <- dcw.(i).(j) + mth fW_partial k * mth2 ((mth p k).dd) i j ;
235       done; done; done in
236       let len = List.length pW_partial in
237       let _ = for ki = 0 to len-1 do 
238         let (k,i,rki) = List.nth pW_partial ki in
239           for mj=0 to len-1 do
240             let (m,j,rmj) = List.nth pW_partial mj in
241               Report.report (Printf.sprintf "k i m j rki rmj fuddkm = %d %d %d %d %f %f %f" k i m j rki.lo rmj.lo (mth2 fu.dd k m).lo);
242               dcw.(i).(j) <- dcw.(i).(j) +  mth2 fu.dd k m * rki * rmj; (* innermost loop *)
243           done; done in
244       let dcw_list =  map Array.to_list (Array.to_list dcw) in
245         make_taylor_interval(lin,w,dcw_list);;
246
247 let rec evalf4 tf w x y z = match tf with
248   | Prim_a p -> evalf4A p w x y z
249   | Uni p -> evalf4U p w x y z
250   | Plus (tf1,tf2) -> ti_add(evalf4 tf1 w x y z, evalf4 tf2 w x y z)
251   | Composite(h,g1,g2,g3,g4,g5,g6) ->
252       let [p1;p2;p3;p4;p5;p6] = map (fun t-> evalf4 t w x y z) [g1;g2;g3;g4;g5;g6] in
253         eval_composite (evalf4 h) p1 p2 p3 p4 p5 p6 w
254   | Scale (tf,t) -> ti_scale ((evalf4 tf w x y z),t)
255   | Uni_compose (uf,tf) -> 
256       evalf4 (Composite(Uni (mk_primitiveU 0 uf),tf,unit,unit,unit,unit,unit)) w x y z;;
257
258 let evalf tf x z = 
259   let (y,w) = center_form (x,z) in
260     evalf4 tf w x y z;;
261
262 let line_estimate_composite =
263   let ( + ) = iadd in
264   let ( * ) = imul in
265     fun h p1 p2 p3 p4 p5 p6 ->
266       let p =  [p1;p2;p3;p4;p5;p6] in
267       let (a,b) = map (fun p -> p.f.lo) p, map (fun p -> p.f.hi) p in 
268       let fN = evalf h a b in
269       let fN_partial = table (fun i -> mk_interval(lower_partial fN i,upper_partial fN i)) in
270       let pN_partial =table2(fun i j-> (mth (mth p i).df j)) in
271       let cN_partial2 = table2 (fun i j -> mth fN_partial j * mth2 pN_partial j i) in
272       let cN_partial = map (end_itlist ( + )) cN_partial2 in
273         mk_line ( fN.l.f, cN_partial );;
274
275 let rec line_estimate tf y = match tf with
276   | Prim_a p -> line_estimateA p y
277   | Uni p -> line_estimateU p y
278   | Plus (p,q) -> ladd (line_estimate p y) (line_estimate q y)
279   | Scale (p,t) -> smul (line_estimate p y) t
280   | Uni_compose (uf,tf) -> 
281       line_estimate (Composite(Uni { slot=0; uv=uf; },tf,unit,unit,unit,unit,unit)) y
282   | Composite(h,g1,g2,g3,g4,g5,g6) ->
283       let [p1;p2;p3;p4;p5;p6] = map (fun t-> line_estimate t y) [g1;g2;g3;g4;g5;g6] in
284         line_estimate_composite h p1 p2 p3 p4 p5 p6;;
285
286 end;;