module Sections = struct

(* Basic commands for working with the goal stack *)
(* Local copy of b() from tactics.ml: undoes the last proof step.
   Drops the most recent goal state from [current_goalstack] and returns
   the resulting stack.
   Fails with "Can't back up any more" when fewer than two goal states are
   stored, i.e. when there is no earlier state to return to.  This includes
   the empty stack, which previously surfaced as a bare "tl" failure. *)
let revert_proof_step() =
  match !current_goalstack with
  | _ :: (_ :: _ as earlier) ->
      current_goalstack := earlier;
      earlier
  | _ -> failwith "Can't back up any more";;


(* A flag for fast proof loading (using mk_thm).
   When set to true, section_proof ignores the tactics it is given and
   closes the goal with a theorem fabricated by mk_thm instead.
   Defaults to false. *)
let fast_load_flag = ref false;;

(* The contents of one section: section variables, hypotheses (with labels),
   implicit types, and auxiliary lemmas.
   New entries are consed onto the front of each list, so the most recently
   added entry comes first. *)
type section_info = 
{
  (* Variables declared in the section *)
  vars : term list;
  (* Hypotheses of the section, each paired with its label *)
  hyps : (string * term) list;
  (* Implicit types: a variable name paired with the type assumed for it *)
  types : (string * hol_type) list;
  (* Auxiliary lemmas, each paired with its name *)
  lemmas : (string * thm) list;
};;


(* A section with no variables, hypotheses, implicit types or lemmas;
   used as the initial contents of a newly opened section *)
let empty_section : section_info = {vars = []; hyps = []; types = []; lemmas = []};;

(* The stack of open sections as (name, contents) pairs.
   The head of the list is the active (most recently opened) section;
   each section's contents live in their own mutable cell. *)
let section_stack = ref ([] : (string * section_info ref) list);;


(* Opens a new, empty section called [name] and makes it the active one.
   Fails if a section with this name is already open. *)
let begin_section name =
  let opened = !section_stack in
  let already_open = can (assoc name) opened in
    if already_open then
      failwith ("Section " ^ name ^ " is already active")
    else
      section_stack := (name, ref empty_section) :: opened;;


(* Ends the active section.
   [name] must be the name of the most recently opened section; the section
   and everything stored in it is then dropped from the stack.
   Fails if no section is open, or if the active section has another name
   (the message reports the name of the section that is actually active). *)
let end_section name =
  match !section_stack with
  | [] -> failwith "end_section: No open sections"
  | (last_name, _) :: outer ->
      (* Plain string inequality: same result as comparing with
         Pervasives.compare, without depending on the Pervasives module *)
      if last_name <> name then
        failwith ("The last open section is " ^ last_name)
      else
        section_stack := outer;;


(* Closes every open section at once by emptying the section stack *)
let end_all_sections () = section_stack := [];;


(* Returns the variables of the active section only
   (the empty list when no section is open) *)
let current_section_vars () =
  match !section_stack with
  | [] -> []
  | (_, active) :: _ -> !active.vars;;


(* Returns the labelled hypotheses of the active section only
   (the empty list when no section is open) *)
let current_section_hyps () =
  match !section_stack with
  | [] -> []
  | (_, active) :: _ -> !active.hyps;;


(* Collects the variables of every open section, starting with those of
   the active section *)
let section_vars () : term list =
  let vars_of (_, info) = !info.vars in
    List.concat (map vars_of !section_stack);;


(* Collects the implicit types of every open section, starting with those
   of the active section *)
let section_types () : (string * hol_type) list =
  let types_of (_, info) = !info.types in
    List.concat (map types_of !section_stack);;


(* Collects the labelled hypotheses of every open section, starting with
   those of the active section *)
let section_hyps () : (string * term) list =
  let hyps_of (_, info) = !info.hyps in
    List.concat (map hyps_of !section_stack);;


(* Collects the named lemmas of every open section, starting with those of
   the active section *)
let section_lemmas () : (string * thm) list =
  let lemmas_of (_, info) = !info.lemmas in
    List.concat (map lemmas_of !section_stack);;


(* Returns every label currently in use in the open sections:
   hypothesis labels first, followed by lemma names *)
let section_labels () : string list =
  let labels_of bindings = map fst bindings in
    labels_of (section_hyps()) @ labels_of (section_lemmas());;


(* Instantiates types in [tm] so that its free variables agree with the
   section variables of the same name.
   Each free variable of [tm] is handled in turn: if a section variable
   with that name exists, the variable's type is matched against the
   declared type and the resulting type instantiation is applied to the
   whole term.
   Fails with a message reporting the declared type when the matching or
   the instantiation fails. *)
let inst_section_vars tm =
  let declared = map dest_var (section_vars()) in
  let inst_var (name, ty) tm =
    (* A name that is not a section variable is given the pair (bool, bool),
       so the match below is bool against bool (presumably yielding no
       instantiation) *)
    let ty_dst, ty_src =
      try (assoc name declared, ty)
      with Failure _ -> (bool_ty, bool_ty) in
      try inst (type_match ty_src ty_dst []) tm
      with Failure _ ->
        failwith ("Section variable " ^ name ^
                    " has type " ^ string_of_type ty_dst) in
  let named_frees = map dest_var (frees tm) in
    itlist inst_var named_frees tm;;


(* Applies the implicit types of the open sections to [tm].
   The variables considered are the outermost universally quantified
   variables of [tm] together with its free variables.  For each of them
   that has an implicit type registered under its name, its type is matched
   against that implicit type; all matches are accumulated into a single
   type instantiation which is then applied to [tm]. *)
let inst_section_types tm =
  let implicit = section_types() in
  let type_pair v =
    let name, ty = dest_var v in
    (* Names without an implicit type give the pair (bool, bool) *)
    try (assoc name implicit, ty) with Failure _ -> (bool_ty, bool_ty) in
  let quantified, _ = strip_forall tm in
  let candidates = quantified @ frees tm in
  let ty_dst, ty_src = unzip (map type_pair candidates) in
    inst (itlist2 type_match ty_src ty_dst []) tm;;


(* Checks that every free variable of [tm] is a section variable.
   With no open section, any free variable at all is rejected.
   Returns unit on success; otherwise fails with a message listing the
   offending variables. *)
let check_section_term tm =
  (* Fails, naming the given variables, unless the list is empty *)
  let reject_unless_empty vars =
    if vars = [] then ()
    else
      let str = String.concat ", " (map string_of_term vars) in
        failwith ("Free variables: " ^ str) in
  let f_vars = frees tm in
    if !section_stack = [] then
      reject_unless_empty f_vars
    else
      reject_unless_empty (subtract f_vars (section_vars()));;
      

(* Declares [var] as a variable of the active section.
   Fails if no section is open, or if a variable with the same name is
   already declared in any open section. *)
let add_section_var var =
  match !section_stack with
  | [] -> failwith "add_section_var: No open sections"
  | (_, section) :: _ ->
      let declared = section_vars() in
      let var_name, _ = dest_var var in
      let name_taken = can (assoc var_name) (map dest_var declared) in
        if name_taken then
          failwith ("A variable with the name "^var_name^" is already defined")
        else
          section := {!section with vars = var :: !section.vars};;


(* Registers an implicit type in the active section.
   [tm] must be a variable: its name becomes the key and its type the
   implicit type for that name.
   Fails if no section is open, or if an implicit type for this name is
   already registered in any open section. *)
let add_section_type tm =
  match !section_stack with
  | [] -> failwith "add_section_type: No open sections"
  | (_, section) :: _ ->
      let registered = section_types() in
      let var_name, ty = dest_var tm in
        if can (assoc var_name) registered then
          failwith ("An implicit type for the variable "^var_name^" is already defined")
        else
          section := {!section with types = (var_name, ty) :: !section.types};;

(* Stores the theorem [th] under [name] as a lemma of the active section.
   Fails if no section is open, or if [name] is already used as a lemma
   name or hypothesis label in any open section. *)
let add_section_lemma name th =
  match !section_stack with
  | [] -> failwith "add_section_lemma: No open sections"
  | (_, section) :: _ ->
      let name_taken = mem name (section_labels()) in
        if name_taken then
          failwith ("A lemma (or hypothesis) with the name " ^ name ^ " is already defined")
        else
          section := {!section with lemmas = (name, th) :: !section.lemmas};;


(* Adds the term [hyp] as a hypothesis of the active section, under [label].
   Before it is stored, the term has its types instantiated against the
   section variables and then against the implicit types.
   Fails if no section is open, if [label] is already used as a label or
   lemma name, if the instantiated term is not boolean, or if it has free
   variables that are not section variables. *)
let add_section_hyp label hyp =
  match !section_stack with
  | [] -> failwith "add_section_hyp: No open sections"
  | (_, section) :: _ ->
      if mem label (section_labels()) then
        failwith ("A hypothesis (or lemma) with the label " ^ label ^ " is already defined")
      else
        let typed_hyp = inst_section_types (inst_section_vars hyp) in
          if type_of typed_hyp = bool_ty then begin
            (* Reject free variables before anything is recorded *)
            check_section_term typed_hyp;
            section := {!section with hyps = (label, typed_hyp) :: !section.hyps}
          end else
            failwith "A boolean term is expected";;


(* Removes every variable called [var_name] from the active section.
   Outer sections are left untouched, and nothing happens when no such
   variable is declared.  Fails if no section is open. *)
let remove_section_var var_name =
  match !section_stack with
  | [] -> failwith "remove_section_var: No open sections"
  | (_, section) :: _ ->
      let has_other_name v = fst (dest_var v) <> var_name in
      let kept = filter has_other_name !section.vars in
        section := {!section with vars = kept};;


(* Removes the implicit type registered for the variable name [type_name]
   from the active section.
   Outer sections are left untouched, and nothing happens when no such
   implicit type exists.  Fails if no section is open. *)
let remove_section_type type_name =
  match !section_stack with
  | [] -> failwith "remove_section_type: No open sections"
  | (_, section) :: _ ->
      (* The tuple pattern is parenthesized so that the function parses
         with and without the camlp5 preprocessor *)
      let kept = filter (fun (name, _) -> name <> type_name) !section.types in
        section := {!section with types = kept};;


(* Removes the lemma called [lemma_name] from the active section.
   Outer sections are left untouched, and nothing happens when no such
   lemma exists.  Fails if no section is open. *)
let remove_section_lemma lemma_name =
  match !section_stack with
  | [] -> failwith "remove_section_lemma: No open sections"
  | (_, section) :: _ ->
      (* The tuple pattern is parenthesized so that the function parses
         with and without the camlp5 preprocessor *)
      let kept = filter (fun (name, _) -> name <> lemma_name) !section.lemmas in
        section := {!section with lemmas = kept};;


(* Removes the hypothesis labelled [label] from the active section.
   Outer sections are left untouched, and nothing happens when no such
   hypothesis exists.  Fails if no section is open. *)
let remove_section_hyp label =
  match !section_stack with
  | [] -> failwith "remove_section_hyp: No open sections"
  | (_, section) :: _ ->
      (* The tuple pattern is parenthesized so that the function parses
         with and without the camlp5 preprocessor *)
      let kept = filter (fun (name, _) -> name <> label) !section.hyps in
        section := {!section with hyps = kept};;


(* Turns [tm] into the term that is actually set as a goal.
   With no open section the term is returned unchanged.  Otherwise its
   types are instantiated against the section variables and the implicit
   types, and the hypotheses of all open sections are added in front of it
   as a chain of implications.
   In both cases the result is checked for free variables that are not
   section variables, and the function fails if there are any. *)
let prepare_goal_term tm =
  let goal =
    if !section_stack = [] then tm
    else begin
      let typed_tm = inst_section_types (inst_section_vars tm) in
      let antecedents = map snd (section_hyps()) in
        itlist (curry mk_imp) antecedents typed_tm
    end in
    check_section_term goal;
    goal;;


(* Builds the goal term for [tm] together with the tactic that sets up
   the proof.
   [names] lists free variables of [tm] that are universally quantified
   before the goal is prepared; a name that is not free in [tm] makes the
   function fail with "Unused variable".
   The returned tactic first adds every section lemma as an assumption
   labelled with its name, then discharges the section hypotheses under
   their labels, and finally strips the quantifiers added for [names]. *)
let prepare_section_proof names tm =
  let typed_frees = map dest_var (frees tm) in
  let as_var name =
    let ty =
      try assoc name typed_frees
      with Failure _ -> failwith ("Unused variable: " ^ name) in
      mk_var (name, ty) in
  let g_vars = map as_var names in
  let goal = prepare_goal_term (list_mk_forall (g_vars, tm)) in
  let hyp_labels = map fst (section_hyps()) in
  let lemmas = section_lemmas() in
  let label_hyp name tac = DISCH_THEN (LABEL_TAC name) THEN tac in
  let label_lemma (name, lemma) tac = LABEL_TAC name lemma THEN tac in
  let gen_tac = REPLICATE_TAC (length g_vars) GEN_TAC in
  let disch_tac = itlist label_hyp hyp_labels ALL_TAC in
  let assume_tac = itlist label_lemma lemmas ALL_TAC in
    goal, assume_tac THEN disch_tac THEN gen_tac;;


(* Starts an interactive proof of [tm] inside the open sections.
   Sets the prepared goal (with no assumptions) as the current goal and
   immediately applies the set-up tactic from prepare_section_proof. *)
let start_section_proof names tm =
  let goal, setup_tac = prepare_section_proof names tm in
    ignore (set_goal ([], goal));
    refine (by (VALID setup_tac));;


(* Returns the theorem proved interactively, in section form.
   Takes the theorem from top_thm and applies UNDISCH to it once for each
   hypothesis of the open sections. *)
let end_section_proof () =
  let proved = top_thm() in
  let undisch_once _ th = UNDISCH th in
    itlist undisch_once (section_hyps()) proved;;


(* Proves [tm] non-interactively using section hypotheses and variables.
   The set-up tactic from prepare_section_proof is applied first, followed
   by the tactics in [tac_list], one after another.  When fast_load_flag is
   set, [tac_list] is ignored and the goal is accepted using a theorem made
   by mk_thm.
   Fails with "section_proof: unsolved goals" if any subgoal is left.
   UNDISCH is applied to the resulting theorem once for each hypothesis of
   the open sections. *)
let section_proof names tm tac_list =
  let goal, setup_tac = prepare_section_proof names tm in
  let initial_state = mk_goalstate ([], goal) in
  let tactics =
    if !fast_load_flag then
      [fun g -> ACCEPT_TAC(mk_thm([], snd g)) g]
    else
      tac_list in
  let _, sgs, just = rev_itlist by (setup_tac :: tactics) initial_state in
    if sgs <> [] then failwith "section_proof: unsolved goals"
    else
      let proved = just null_inst [] in
        itlist (fun _ th -> UNDISCH th) (section_hyps()) proved;;

    
(* Makes [th] independent of the active section.
   Discharges those hypotheses of the active section that occur among the
   hypotheses of [th], then generalizes over those variables of the active
   section that are free in the resulting conclusion.
   Only the active section is considered; hypotheses and variables of outer
   sections are left alone. *)
let finalize_theorem th =
  let active_hyps = map snd (current_section_hyps()) in
  let used_hyps = intersect active_hyps (hyp th) in
  let active_vars = current_section_vars() in
  let discharged = rev_itlist DISCH used_hyps th in
  let used_vars = intersect (frees (concl discharged)) active_vars in
    itlist GEN used_vars discharged;;
    
end;;