Update from HH
[Flyspeck/.git] / development / thales / ocaml / rank_boost.hl
1 (* Thomas Hales, June 29, 2011,
2    implementation of the rank boost algorithm from
3    "An Efficient Boosting Algorithm for Combining Preferences"
4    Freund, Iyer, Schapire, and Singer.
5
6    The 'bottom' symbol _|_ is implemented by throwing an exception.
7
8 *)
9
10
11 module Rank_boost = struct
12
13 (**********************************************************************)
14 (* matrix ops *)
15
16 let exp = Pervasives.exp;; 
17
18 let abs = Pervasives.abs_float;;
19
20 let infinity = Pervasives.infinity;;
21
22 let neg_infinity = Pervasives.neg_infinity;;
23
24 let nth = List.nth;;
25
26 let scale c = map ( ( *. ) c) ;;
27
28 let outer_col_mul d v = 
29   map (map2 ( *. ) v) d;;
30
31 let outer_row_mul d v = 
32   map2  scale  v d;;
33
34 let sum = end_itlist (+.);;
35
36 (* examples *)
37
38 outer_col_mul [[1.0;2.0];[3.0;4.0]] [7.0;11.0];;
39 outer_row_mul [[1.0;2.0];[3.0;4.0]] [7.0;11.0];;
40
41 (**********************************************************************)
42 (* boost utilities *)
43
44 let alpha_r r = 0.5 *. log ((1.0 +. r) /. (1.0 -. r));;
45
46 let normalize d = 
47   let total = sum (map sum d) in
48     map (scale (1.0 /. total)) d;;
49
50 let mk_distrib phi X = 
51   let row y = map (fun x -> max 0.0 (phi y x)) X in
52     normalize(map row X);;
53
54 (*
55 let mk_uniform_distribution n = 
56   let row = replicate 1.0 n in
57     normalize (replicate row n)  ;;
58 *)
59
60 let update_distribution distrib alpha hs = 
61   let h1 = map (fun t -> exp(alpha *. t)) hs in
62   let h2 = map (fun t -> 1.0 /. t) h1 in
63   let d =outer_col_mul distrib h2 in
64   let d = outer_row_mul d h1 in
65     normalize d;;
66
67 let sort_zipf f zX = 
68   let zf = mapfilter (fun (i,x) -> (i,f x)) zX in
69     sort (fun (_,f1) (_,f2) -> f1 > f2) zf;;
70
71 let mk_pi distrib =
72   let row_sums = map sum distrib in
73   let col_sums = end_itlist (map2 (+.)) distrib in
74      map2 ( -. ) row_sums col_sums;;
75
76 (* weak_update
77   might behave slightly incorrectly when the optimal f' is a repeated value of f.
78     There is no problem, if the ranking f is strict.  
79
80    default_q currently doesn't get used in the end. *)
81     
82 let  weak_update R pi default_q (L,r,theta,q) (i',f') =
83   let L  = L +. nth pi i' in
84   let q' = match default_q with
85     Some q' -> q'
86     | None -> if (abs(L) > abs(L -. R)) then 0.0 else 1.0 in
87     if abs(L -. q *. R) > abs(r) then (L,L-. q*. R,f',q') else (L,r,theta,q);;
88
89 let weak_learn_one pi default_q zf =
90   let R = List.fold_left (fun s (i,_) -> s +. nth pi i) 0.0 zf  in
91   List.fold_left (weak_update R pi default_q) (0.0,0.0,Pervasives.infinity,0.0) zf  ;;
92
93 let weak_learn_all distrib ranks sfs =
94   let pi = mk_pi distrib in
95   let wl = weak_learn_one pi None in
96   let foldf (r,absr,fx,q,f) (zf,f') = 
97     let (_,r',fx',q') = wl zf in
98       (if (abs r' > absr) then (r',abs r',fx',q',f') else (r,absr,fx,q,f))  in
99   let (r,_,theta,q,f) =   List.fold_left foldf
100     (0.0,neg_infinity,0.0,0.0,(fun t ->0.0))
101     (zip sfs ranks) in
102   let alpha = alpha_r r in
103     (alpha,q,theta,f);;
104
105 let learned_h q theta f x =
106   try ( if f x >= theta then 1.0 else 0.0) with _ -> q;;
107
108 (* example: *)
109
110 sort_zipf I (zipX [2.0;3.0;4.0]);;
111 sort_zipf (fun t -> if t < 3.0 then t else failwith "bad") (zipX [7.0;4.0;3.0;2.0;2.5;1.0]);;
112 mk_pi [[1.0;2.0];[3.0;4.0]];;
113
114 (**********************************************************************)
115 (* running the boost algorithm  T times *)
116
117 (*
118 X:A = domain
119 ranks:(A->real) list, (ranking functions)
120 phi:A -> A-> float, feedback function.
121 t=number of iterations, typically 40-150 range. paper uses 40+n/10, n=length ranks.
122 We don't preset theta, we compute it as some value fx.
123 *)
124
125 let rec rank_boost t ranks wts distrib sfs X = 
126   if (t=List.length wts) then wts
127   else
128     let (alpha,q,theta,f)=weak_learn_all distrib ranks sfs in
129     let h = learned_h q theta f in
130     let distrib' =  update_distribution distrib alpha (map h X) in
131       rank_boost t ranks ((alpha,q,theta,f)::wts) distrib' sfs X;;
132
133 let calc_wts X phi ranks t = 
134   let distrib = mk_distrib phi  X in
135   let zipX = zip (0--(List.length X - 1)) X in
136   let sfs =   map (fun f -> sort_zipf f zipX)  ranks in
137     rank_boost t ranks [] distrib sfs X;;
138
139 let evalf wts x = List.fold_left 
140   (fun s (alpha,q,theta,f) -> s +. alpha *. learned_h q theta f x) 0.0 wts;;
141
142 end;;