(*****************************)
(* Cached natural arithmetic *)
(*****************************)

(* Dependencies *)
needs "../formal_lp/arith/arith_options.hl";;
needs ("../formal_lp/arith/"^(if Arith_options.binary then "arith_hash.hl" else "arith_hash2.hl"));;


module Arith_cache = struct

(* Hash tables *)

(* Initial capacity handed to Hashtbl.create for every result table:
   10000 when Arith_options.cached is true, otherwise 1. *)
let cache_size =
  match Arith_options.cached with
  | true -> 10000
  | false -> 1;;

(* Number of bindings at which my_add empties a table before inserting. *)
let max_cache_size = 2 * cache_size;;

(* [my_add h key v] binds [key] to [v] in the hash table [h].
   If [h] already holds max_cache_size bindings or more, it is emptied
   first and a notice is written to standard output; the new binding is
   added in either case. *)
let my_add h key v =
  if Hashtbl.length h >= max_cache_size then begin
    Hashtbl.clear h;
    (* The notice ends with a newline so that repeated notices and any
       output printed afterwards do not run together on one line. *)
    print_string "Clearing a nat hash table\n"
  end;
  Hashtbl.add h key v;;

(* One result table per cached operation, each created with capacity
   cache_size.  The cached wrappers in this module key them with the
   strings built by op_tm_hash or tm1_tm2_hash and store whatever the
   corresponding Arith_hash function returned. *)
let le_table = Hashtbl.create cache_size;;
let add_table = Hashtbl.create cache_size;;
let sub_table = Hashtbl.create cache_size;;
let sub_le_table = Hashtbl.create cache_size;;
let mul_table = Hashtbl.create cache_size;;
let div_table = Hashtbl.create cache_size;;

(* Counters for collecting stats: one integer call counter per wrapped
   operation, all starting at zero. *)
let suc_counter = ref 0;;
let eq0_counter = ref 0;;
let pre_counter = ref 0;;
let gt0_counter = ref 0;;
let lt_counter = ref 0;;
let le_counter = ref 0;;
let add_counter = ref 0;;
let sub_counter = ref 0;;
let sub_le_counter = ref 0;;
let mul_counter = ref 0;;
let div_counter = ref 0;;
let even_counter = ref 0;;
let odd_counter = ref 0;;


(* Clears all cached results by emptying each of the six result tables.
   The tables are cleared one by one rather than through a list because
   they are not required to share a single value type. *)
let reset_cache () =
  Hashtbl.clear le_table;
  Hashtbl.clear add_table;
  Hashtbl.clear sub_table;
  Hashtbl.clear sub_le_table;
  Hashtbl.clear mul_table;
  Hashtbl.clear div_table;;

	
(* Resets all counters: every call counter is set back to zero.
   The result tables are left untouched. *)
let reset_stat () =
  List.iter (fun counter -> counter := 0)
    [suc_counter; eq0_counter; pre_counter; gt0_counter; lt_counter;
     le_counter; add_counter; sub_counter; sub_le_counter; mul_counter;
     div_counter; even_counter; odd_counter];;

(* Prints stats to standard output: the value of every call counter and,
   for each cached operation, the current number of bindings in its
   result table.  The whole report is assembled first and written with a
   single print_string call. *)
let print_stat () =
  let report_parts = [
    sprintf "suc = %d\npre = %d\n" !suc_counter !pre_counter;
    sprintf "eq0 = %d\ngt0 = %d\n" !eq0_counter !gt0_counter;
    sprintf "lt = %d\n" !lt_counter;
    sprintf "even = %d\nodd = %d\n" !even_counter !odd_counter;
    sprintf "le = %d (le_hash = %d)\n"
      !le_counter (Hashtbl.length le_table);
    sprintf "add = %d (add_hash = %d)\n"
      !add_counter (Hashtbl.length add_table);
    sprintf "sub = %d (sub_hash = %d)\n"
      !sub_counter (Hashtbl.length sub_table);
    sprintf "sub_le = %d (sub_le_hash = %d)\n"
      !sub_le_counter (Hashtbl.length sub_le_table);
    sprintf "mul = %d (mul_hash = %d)\n"
      !mul_counter (Hashtbl.length mul_table);
    sprintf "div = %d (div_hash = %d)\n"
      !div_counter (Hashtbl.length div_table);
  ] in
  print_string (String.concat "" report_parts);;

	
(* Note: the standard Hashtbl.hash function works very poorly on terms,
   which is presumably why the table keys are built as strings here. *)

(* [num_tm_hash tm] walks down [tm] for as long as it is a combination.
   At each step the name of the constant in operator position is taken
   and the walk continues with the operand; the names are concatenated
   outermost first.  A term that is not a combination contributes the
   empty string.  dest_const is applied to every operator, so any failure
   it raises on a non-constant operator propagates to the caller.

   The names are accumulated in a Buffer and the walk is tail-recursive;
   this produces the same string as concatenating with ^ at every level
   but avoids recopying the partial key at each step. *)
let num_tm_hash tm =
  let buf = Buffer.create 64 in
  let rec walk t =
    if is_comb t then begin
      let b_tm, n_tm = dest_comb t in
      Buffer.add_string buf (fst (dest_const b_tm));
      walk n_tm
    end
  in
  walk tm;
  Buffer.contents buf;;

(* [op_tm_hash tm] builds the table key for a term of the form
   (op arg1) arg2.  The term is split with dest_comb and the first
   argument is recovered with rand.  The key is num_tm_hash of the first
   argument, the letter x, then num_tm_hash of the second argument.
   The operator itself is not part of the key. *)
let op_tm_hash tm =
  let op_with_arg1, arg2 = dest_comb tm in
  let arg1 = rand op_with_arg1 in
  String.concat "x" [num_tm_hash arg1; num_tm_hash arg2];;


(* [tm1_tm2_hash tm1 tm2] builds a table key from two separate terms:
   num_tm_hash of [tm1], the letter x, then num_tm_hash of [tm2]. *)
let tm1_tm2_hash tm1 tm2 =
  String.concat "x" [num_tm_hash tm1; num_tm_hash tm2];;


(* SUC *)
(* Increments suc_counter, then passes [tm] unchanged to
   Arith_hash.raw_suc_conv_hash and returns its result.  Not cached. *)
let raw_suc_conv_hash tm =
  incr suc_counter;
  Arith_hash.raw_suc_conv_hash tm;;

(* x = 0 *)
(* Increments eq0_counter, then passes [tm] unchanged to
   Arith_hash.raw_eq0_hash_conv and returns its result.  Not cached. *)
let raw_eq0_hash_conv tm =
  incr eq0_counter;
  Arith_hash.raw_eq0_hash_conv tm;;

(* PRE *)
(* Increments pre_counter, then passes [tm] unchanged to
   Arith_hash.raw_pre_hash_conv and returns its result.  Not cached. *)
let raw_pre_hash_conv tm =
  incr pre_counter;
  Arith_hash.raw_pre_hash_conv tm;;

(* x > 0 *)
(* Increments gt0_counter, then passes [tm] unchanged to
   Arith_hash.raw_gt0_hash_conv and returns its result.  Not cached. *)
let raw_gt0_hash_conv tm =
  incr gt0_counter;
  Arith_hash.raw_gt0_hash_conv tm;;

(* x < y *)
(* Increments lt_counter, then passes [tm] unchanged to
   Arith_hash.raw_lt_hash_conv and returns its result.  Not cached. *)
let raw_lt_hash_conv tm =
  incr lt_counter;
  Arith_hash.raw_lt_hash_conv tm;;

(* x <= y *)
(* Cached wrapper around Arith_hash.raw_le_hash_conv.
   Every call increments le_counter.  The key is op_tm_hash of [tm]; a
   hit in le_table is returned directly, and on a miss the result is
   computed by Arith_hash, stored through my_add and returned. *)
let raw_le_hash_conv tm =
  incr le_counter;
  let key = op_tm_hash tm in
  try Hashtbl.find le_table key
  with Not_found ->
    let computed = Arith_hash.raw_le_hash_conv tm in
    my_add le_table key computed;
    computed;;

(* x + y *)
(* Cached wrapper around Arith_hash.raw_add_conv_hash.
   Every call increments add_counter.  The key is op_tm_hash of [tm]; a
   hit in add_table is returned directly, and on a miss the result is
   computed by Arith_hash, stored through my_add and returned. *)
let raw_add_conv_hash tm =
  incr add_counter;
  let key = op_tm_hash tm in
  try Hashtbl.find add_table key
  with Not_found ->
    let computed = Arith_hash.raw_add_conv_hash tm in
    my_add add_table key computed;
    computed;;

(* x - y *)
(* Cached wrapper around Arith_hash.raw_sub_hash_conv.
   Every call increments sub_counter.  The key is op_tm_hash of [tm]; a
   hit in sub_table is returned directly, and on a miss the result is
   computed by Arith_hash, stored through my_add and returned. *)
let raw_sub_hash_conv tm =
  incr sub_counter;
  let key = op_tm_hash tm in
  try Hashtbl.find sub_table key
  with Not_found ->
    let computed = Arith_hash.raw_sub_hash_conv tm in
    my_add sub_table key computed;
    computed;;

(* Cached wrapper around Arith_hash.raw_sub_and_le_hash_conv, which takes
   the two operand terms separately.
   Every call increments sub_le_counter.  The key is tm1_tm2_hash of the
   two terms; a hit in sub_le_table is returned directly, and on a miss
   the result is computed by Arith_hash, stored through my_add and
   returned. *)
let raw_sub_and_le_hash_conv tm1 tm2 =
  incr sub_le_counter;
  let key = tm1_tm2_hash tm1 tm2 in
  try Hashtbl.find sub_le_table key
  with Not_found ->
    let computed = Arith_hash.raw_sub_and_le_hash_conv tm1 tm2 in
    my_add sub_le_table key computed;
    computed;;

(* x * y *)
(* Cached wrapper around Arith_hash.raw_mul_conv_hash.
   Every call increments mul_counter.  The key is op_tm_hash of [tm]; a
   hit in mul_table is returned directly, and on a miss the result is
   computed by Arith_hash, stored through my_add and returned. *)
let raw_mul_conv_hash tm =
  incr mul_counter;
  let key = op_tm_hash tm in
  try Hashtbl.find mul_table key
  with Not_found ->
    let computed = Arith_hash.raw_mul_conv_hash tm in
    my_add mul_table key computed;
    computed;;

(* x / y *)
(* Cached wrapper around Arith_hash.raw_div_hash_conv.
   Every call increments div_counter.  The key is op_tm_hash of [tm]; a
   hit in div_table is returned directly, and on a miss the result is
   computed by Arith_hash, stored through my_add and returned. *)
let raw_div_hash_conv tm =
  incr div_counter;
  let key = op_tm_hash tm in
  try Hashtbl.find div_table key
  with Not_found ->
    let computed = Arith_hash.raw_div_hash_conv tm in
    my_add div_table key computed;
    computed;;

(* EVEN, ODD *)
(* Increments even_counter, then passes [tm] unchanged to
   Arith_hash.raw_even_hash_conv and returns its result.  Not cached. *)
let raw_even_hash_conv tm =
  incr even_counter;
  Arith_hash.raw_even_hash_conv tm;;

(* Increments odd_counter, then passes [tm] unchanged to
   Arith_hash.raw_odd_hash_conv and returns its result.  Not cached. *)
let raw_odd_hash_conv tm =
  incr odd_counter;
  Arith_hash.raw_odd_hash_conv tm;;

end;;