(* Thomas Hales, June 29, 2011,
   implementation of the rank boost algorithm from
   "An Efficient Boosting Algorithm for Combining Preferences"
   Freund, Iyer, Schapire, and Singer.

   The 'bottom' symbol _|_ is implemented by throwing an exception.

*)


module Rank_boost = struct

(**********************************************************************)
(* matrix ops *)

(* Local names for the float primitives and list indexing used below.
   Inside this module they always refer to the Pervasives and List
   versions, whatever else is bound to these names in the enclosing scope. *)

let exp x = Pervasives.exp x;;

(* float absolute value, not the integer one *)
let abs x = Pervasives.abs_float x;;

let infinity = Pervasives.infinity;;

let neg_infinity = Pervasives.neg_infinity;;

(* nth l n: element number n of the list l, counting from 0 *)
let nth l n = List.nth l n;;

(* scale c v: multiply every entry of the float list v by the scalar c.
   map is not defined in this file; it comes from the enclosing
   environment and is assumed to be the usual list map. *)
let scale c v = map (fun x -> c *. x) v;;

(* outer_col_mul d v: d is a matrix given as a list of rows; entry j of
   every row is multiplied by entry j of v, so column j is scaled by v_j. *)
let outer_col_mul d v =
  let times_v row = map2 ( *. ) v row in
    map times_v d;;

(* outer_row_mul d v: d is a matrix given as a list of rows; row i is
   scaled by entry i of v. *)
let outer_row_mul d v =
  map2 (fun c row -> scale c row) v d;;

(* sum v: total of the float list v, folded with end_itlist.
   NOTE(review): end_itlist is not defined in this file; in HOL Light it
   fails on the empty list, so v is presumably required to be nonempty. *)
let sum v = end_itlist (fun a b -> a +. b) v;;

(* examples *)

(* Expected values below assume map and map2 are the usual list maps. *)
(* columns scaled by 7 and 11: [[7.0; 22.0]; [21.0; 44.0]] *)
outer_col_mul [[1.0;2.0];[3.0;4.0]] [7.0;11.0];;
(* rows scaled by 7 and 11: [[7.0; 14.0]; [33.0; 44.0]] *)
outer_row_mul [[1.0;2.0];[3.0;4.0]] [7.0;11.0];;

(**********************************************************************)
(* boost utilities *)

(* alpha_r r: the weight 1/2 ln((1+r)/(1-r)) assigned to a weak ranking
   whose correlation with the current distribution is r.
   The value is 0 at r = 0 and has the same sign as r for -1 < r < 1. *)
let alpha_r r =
  let ratio = (1.0 +. r) /. (1.0 -. r) in
    0.5 *. log ratio;;

(* normalize d: rescale the matrix d (a list of rows) by one common
   factor, the reciprocal of the sum of all its entries. *)
let normalize d =
  let row_totals = map sum d in
  let inv_total = 1.0 /. sum row_totals in
    map (fun row -> scale inv_total row) d;;

(* mk_distrib phi X: initial distribution over pairs of instances.
   Entry (y,x), in row y and column x, is max 0 (phi y x); the whole
   matrix is then normalized.  Rows and columns follow the order of X. *)
let mk_distrib phi X =
  let weight y x = max 0.0 (phi y x) in
  let raw = map (fun y -> map (weight y) X) X in
    normalize raw;;

(*
let mk_uniform_distribution n = 
  let row = replicate 1.0 n in
    normalize (replicate row n)  ;;
*)

(* update_distribution distrib alpha hs: reweight the pair distribution
   after a round.  hs lists the weak ranking value h of every instance.
   Entry (i,j) is multiplied by exp(alpha h_i) and divided by
   exp(alpha h_j), and the result is normalized. *)
let update_distribution distrib alpha hs =
  let up = map (fun h -> exp (alpha *. h)) hs in
  let down = map (fun u -> 1.0 /. u) up in
  let by_column = outer_col_mul distrib down in
    normalize (outer_row_mul by_column up);;

(* sort_zipf f zX: zX is a list of (index, instance) pairs.  Each
   instance is replaced by its value under f and the pairs are ordered
   by decreasing value.
   NOTE(review): mapfilter and sort are not defined in this file; in
   HOL Light mapfilter drops the elements on which f fails, which is how
   unranked instances would disappear here. *)
let sort_zipf f zX =
  let score (i,x) = (i,f x) in
  let larger_first (_,a) (_,b) = a > b in
    sort larger_first (mapfilter score zX);;

(* mk_pi distrib: the potential pi of every instance,
     pi(x) = sum over y of ( D(y,x) - D(x,y) ),
   that is, the column sum minus the row sum of the matrix distrib
   (a list of rows).

   The sign matters.  With this definition the sum over x of pi(x) h(x)
   equals the sum over pairs of D(x0,x1) (h(x1) - h(x0)), which is the
   correlation r of the paper.  update_distribution multiplies D(x0,x1)
   by exp(alpha (h(x0) - h(x1))), so alpha = alpha_r r must have the sign
   of this r for the reweighting to shrink the normalizer.  The previous
   row-minus-column form produced -r and hence the wrong sign of alpha.

   NOTE(review): end_itlist presumably fails on an empty matrix. *)
let mk_pi distrib =
  let row_sums = map sum distrib in
  let col_sums = end_itlist (map2 (+.)) distrib in
     map2 ( -. ) col_sums row_sums;;

(* weak_update
  might behave slightly incorrectly when the optimal f' is a repeated value of f.
    There is no problem, if the ranking f is strict.  

   default_q currently doesn't get used in the end. *)
    
(* weak_update R pi default_q (L,r,theta,q) (i',f')
   One step of the threshold scan over the instances of a single ranking
   feature, visited in decreasing order of feature value.

   R          sum of pi over all instances ranked by the feature
   pi         potentials, indexed by instance number
   default_q  Some v forces the default value v; None chooses 0.0 or 1.0
              by comparing |L| with |L - R|
   (L,r,theta,q)
              L is the running sum of pi over the instances seen so far;
              r, theta, q are the best correlation found so far with its
              threshold and default value
   (i',f')    index and feature value of the instance being added

   Returns the updated accumulator.  The candidate correlation is
   L - q' R, computed with the default value q' chosen in this step, and
   it replaces the best one when it is strictly larger in absolute value.
   Earlier the candidate was computed with the accumulator q of a
   previous step, so the stored r and the stored q' did not belong
   together. *)
let  weak_update R pi default_q (L,r,theta,q) (i',f') =
  let L  = L +. nth pi i' in
  let q' = match default_q with
    Some q' -> q'
    | None -> if (abs(L) > abs(L -. R)) then 0.0 else 1.0 in
  let r' = L -. q' *. R in
    if abs r' > abs r then (L,r',f',q') else (L,r,theta,q);;

(* weak_learn_one pi default_q zf: scan one ranking feature.
   zf is the list of (index, feature value) pairs of the instances the
   feature ranks, as produced by sort_zipf.  The result is the final
   accumulator (L, r, theta, q) of weak_update, started from r = 0 and
   threshold infinity. *)
let weak_learn_one pi default_q zf =
  let add_pi acc (i,_) = acc +. nth pi i in
  let ranked_total = List.fold_left add_pi 0.0 zf in
  let start = (0.0,0.0,infinity,0.0) in
    List.fold_left (weak_update ranked_total pi default_q) start zf;;

(* weak_learn_all distrib ranks sfs: pick the best weak ranking.
   ranks is the list of ranking features and sfs the matching list of
   sorted (index, value) lists.  Every feature is scanned with
   weak_learn_one and the one with the largest |r| wins; on ties the
   earlier feature is kept.  Returns (alpha, q, theta, f) with
   alpha = alpha_r r.  If there are no features the result is built from
   r = 0 and the constant function 0. *)
let weak_learn_all distrib ranks sfs =
  let pi = mk_pi distrib in
  let pick ((_,best_abs,_,_,_) as best) (zf,rank) =
    let (_,r,theta,q) = weak_learn_one pi None zf in
    let abs_r = abs r in
      if abs_r > best_abs then (r,abs_r,theta,q,rank) else best in
  let start = (0.0,neg_infinity,0.0,0.0,(fun _ -> 0.0)) in
  let (r,_,theta,q,f) = List.fold_left pick start (zip sfs ranks) in
    (alpha_r r,q,theta,f);;

(* learned_h q theta f x: the thresholded weak ranking.
   1.0 when f x is at least theta, 0.0 when it is below, and the default
   value q when evaluating f x raises any exception (x is unranked). *)
let learned_h q theta f x =
  try
    let fx = f x in
      if fx >= theta then 1.0 else 0.0
  with _ -> q;;

(* example: *)

(* The instances are paired with their positions 0..n-1 in the same way
   calc_wts does it; no zipX function is defined in this file.
   Expected values assume the HOL Light meaning of zip, --, mapfilter
   and sort. *)
(* all three ranked, largest first: [(2,4.0); (1,3.0); (0,2.0)] *)
sort_zipf I (zip (0--2) [2.0;3.0;4.0]);;
(* values from 3.0 up fail and are dropped: [(4,2.5); (3,2.0); (5,1.0)] *)
sort_zipf (fun t -> if t < 3.0 then t else failwith "bad") (zip (0--5) [7.0;4.0;3.0;2.0;2.5;1.0]);;
(* row sums are [3.0; 7.0], column sums are [4.0; 6.0] *)
mk_pi [[1.0;2.0];[3.0;4.0]];;

(**********************************************************************)
(* running the boost algorithm  T times *)

(*
X:A = domain
ranks:(A->real) list, (ranking functions)
phi:A -> A-> float, feedback function.
t=number of iterations, typically 40-150 range. paper uses 40+n/10, n=length ranks.
We don't preset theta, we compute it as some value fx.
*)

(* rank_boost t ranks wts distrib sfs X: run boosting rounds until t
   weak rankings have been collected.

   t        number of weak rankings wanted
   ranks    ranking features
   wts      weak rankings found so far, newest first, each stored as
            (alpha, q, theta, f)
   distrib  current distribution over pairs of instances
   sfs      sorted (index, value) lists, one per feature
   X        the instances

   Each round picks the best weak ranking for the current distribution,
   reweights the distribution with it, and recurses.  The stop test is
   an inequality so that a negative t, or a wts list that is already
   longer than t, returns wts at once instead of recursing forever. *)
let rec rank_boost t ranks wts distrib sfs X = 
  if (List.length wts >= t) then wts
  else
    let (alpha,q,theta,f)=weak_learn_all distrib ranks sfs in
    let h = learned_h q theta f in
    let distrib' =  update_distribution distrib alpha (map h X) in
      rank_boost t ranks ((alpha,q,theta,f)::wts) distrib' sfs X;;

(* calc_wts X phi ranks t: entry point.  Builds the initial distribution
   from the feedback function phi, numbers the instances of X from 0,
   sorts them once under every ranking feature, and runs rank_boost for
   t rounds starting from an empty list of weak rankings. *)
let calc_wts X phi ranks t =
  let start_distrib = mk_distrib phi X in
  let last = List.length X - 1 in
  let indexed = zip (0--last) X in
  let sorted_features = map (fun f -> sort_zipf f indexed) ranks in
    rank_boost t ranks [] start_distrib sorted_features X;;

(* evalf wts x: the combined ranking score of x, the sum over the weak
   rankings (alpha, q, theta, f) in wts of alpha times learned_h q theta f x.
   The score of an empty wts is 0.0. *)
let evalf wts x =
  let add_vote total (alpha,q,theta,f) =
    total +. alpha *. learned_h q theta f x in
    List.fold_left add_vote 0.0 wts;;

end;;