(*
copied from repository 'http://seanmcl-ocaml-lib.googlecode.com/svn'
*)
(*
   Print the names and types of the atoms (variables and constants)
   occurring in terms, theorems, and the current goal.
*)
module Print_types
  : sig
      (** Print the (name, type) pairs of the atoms occurring in the
          conclusion and assumptions of the current goal. *)
      val print_goal_types: unit -> unit

      (** Print the (name, type) pairs of the atoms occurring in the
          conclusion of a theorem. *)
      val print_thm_types: thm -> unit

      (** Print the (name, type) pairs of the atoms occurring in a term,
          duplicates removed and sorted, one pair per line. *)
      val print_term_types: term -> unit

      (** Print a warning, followed by the clashing atoms, when some
          variable of the current goal shares its name with a variable or
          non-suppressed constant of a different type.  Prints nothing
          when there is no such clash. *)
      val print_goal_var_overload: unit -> unit

      (** [get_term_types tm] lists, left to right and with repetitions,
          the (name, type) pairs of the variables and constants of [tm]
          whose names are not suppressed. *)
      val get_term_types: term -> (string*hol_type) list

      (** [get_const_types tm] is like [get_term_types tm] but reports
          constants only. *)
      val get_const_types: term -> (string*hol_type) list

      (** Common symbols are suppressed by default to avoid clutter.
          [suppress s] adds the name [s] to the suppressed list and
          [unsuppress s] removes it again. *)
      val suppress: string -> unit
      val unsuppress: string -> unit

    end =
struct

  (* Names that are left out of the reported atoms. *)
  let suppressed = ref
    ["==>";"?";"!";"/\\";"\\/";",";"~";"APPEND";"CONS";"HD";"LAST";
     "NIL";"=";"real_lt";"real_gt";"real_le";"real_ge";"BIT0";"BIT1";"NUMERAL";
     "DECIMAL";"real_sub";"real_add";"real_neg";
     "real_of_num";"_0";"_1";"real_div";"real_mul";"real_pow";"COND";"LET";"LET_END"]

  (* Adding a name that is already present is a no-op, so repeated calls
     do not grow the list. *)
  let suppress s =
    if not (mem s !suppressed) then suppressed := s :: !suppressed

  (* Names must be compared structurally.  Physical inequality is true
     for any two distinct string values, even with equal contents, so a
     filter based on it would keep every entry and remove nothing. *)
  let unsuppress s =
    suppressed := List.filter (fun name -> name <> s) !suppressed

  let not_suppressed s = not (mem s !suppressed)

  (* [atoms keep_var keep_const tm] lists the (name, type) pairs of the
     atoms of [tm] in left-to-right order, keeping a variable when
     [keep_var name] holds and a constant when [keep_const name] holds.
     Results are consed onto an accumulator rather than appended, so the
     traversal is linear in the size of the term. *)
  let atoms keep_var keep_const tm =
    let rec go tm acc =
      match tm with
      | Var (s, ty) -> if keep_var s then (s, ty) :: acc else acc
      | Const (s, ty) -> if keep_const s then (s, ty) :: acc else acc
      | Comb (t1, t2) | Abs (t1, t2) -> go t1 (go t2 acc)
    in
    go tm []

  let get_term_types tm = atoms not_suppressed not_suppressed tm

  let get_const_types tm = atoms (fun _ -> false) not_suppressed tm

  (* Variables are reported regardless of the suppressed list. *)
  let get_var_types tm = atoms (fun _ -> true) (fun _ -> false) tm

  (* Print one atom as a pair of its quoted name and its type, followed
     by a newline. *)
  let print_atom_type : string * hol_type -> unit =
    fun (s, ty) ->
      print_string ("(\"" ^ s ^ "\", ");
      print_type ty;
      print_string ")\n"

  (* Remove duplicates and sort. *)
  let canonical atom_list = sort (<) (setify atom_list)

  let setify_types tm = canonical (get_term_types tm)

  let print_term_types tm = List.iter print_atom_type (setify_types tm)

  let print_thm_types th = print_term_types (concl th)

  (* Conjoin the conclusion and the assumptions of the current goal into
     a single term so that one traversal sees all of them.  Any exception
     raised by top_goal propagates unchanged. *)
  let goal_term () =
    let (asms, g) = top_goal () in
    end_itlist (curry mk_conj) (g :: asms)

  let print_goal_types () = print_term_types (goal_term ())

  let print_goal_var_overload () =
    let tm = goal_term () in
    let vars = canonical (get_var_types tm) in
    let all_atoms = canonical (get_const_types tm @ vars) in
    (* Atoms that carry the same name as the given variable but a
       different type. *)
    let clashes (name, ty) =
      filter (fun (name', ty') -> name' = name && ty' <> ty) all_atoms in
    match canonical (List.flatten (map clashes vars)) with
    | [] -> ()
    | overloaded ->
        print_string "Warning: type-overloaded variables\n";
        List.iter print_atom_type overloaded

end;;

open Print_types;;