(* ========================================================================= *)
(* Representation of primes == 1 (mod 4) as sum of 2 squares.                *)
(* ========================================================================= *)

needs "Library/prime.ml";;

(* Resolve overloaded arithmetic symbols to type num by default, so the      *)
(* terms quoted below are read over the natural numbers.                     *)
prioritize_num();;

(* ------------------------------------------------------------------------- *)
(* Definition of involution and various basic lemmas.                        *)
(* ------------------------------------------------------------------------- *)

(* involution f s: for every x in s, f(x) is again in s and applying f       *)
(* twice returns x. Nothing is required of f outside s.                      *)
let involution = new_definition
  `involution f s = !x. x IN s ==> f(x) IN s /\ (f(f(x)) = x)`;;
(* The image of s under an involution on s is s itself. *)
let INVOLUTION_IMAGE = prove
 (`!f s. involution f s ==> (IMAGE f s = s)`,
  REWRITE_TAC[involution; EXTENSION; IN_IMAGE] THEN
  MESON_TAC[]);;
(* Removing a fixpoint a of f from s leaves f an involution on the rest. *)
let INVOLUTION_DELETE = prove
 (`involution f s /\ a IN s /\ (f a = a)
   ==> involution f (s DELETE a)`,
  REWRITE_TAC[involution; IN_DELETE] THEN
  MESON_TAC[]);;
(* Removing an element a together with its partner f a from s leaves f an   *)
(* involution on what remains.                                               *)
let INVOLUTION_STEPDOWN = prove
 (`involution f s /\ a IN s
   ==> involution f (s DIFF {a, (f a)})`,
  REWRITE_TAC[involution; IN_DIFF; IN_INSERT; NOT_IN_EMPTY] THEN
  MESON_TAC[]);;
(* An involution on s restricts to an involution on its non-fixed points. *)
let INVOLUTION_NOFIXES = prove
 (`involution f s
   ==> involution f {x | x IN s /\ ~(f x = x)}`,
  REWRITE_TAC[involution; IN_ELIM_THM] THEN
  MESON_TAC[]);;
(* An involution on s is also an involution on any subset t of s that f     *)
(* maps into itself.                                                         *)
let INVOLUTION_SUBSET = prove
 (`!f s t. involution f s /\
           (!x. x IN t ==> f(x) IN t) /\
           t SUBSET s
           ==> involution f t`,
  REWRITE_TAC[involution; SUBSET] THEN
  MESON_TAC[]);;
(* ------------------------------------------------------------------------- *)
(* Involution with no fixpoints can only occur on finite set of even card.   *)
(* ------------------------------------------------------------------------- *)
(* One step of the pairing argument: for a fixpoint-free involution on a    *)
(* finite set s and any a in s, removing the pair a, f a keeps all the      *)
(* hypotheses and lowers the cardinality by exactly 2.                      *)
let INVOLUTION_EVEN_STEP = prove
 (`FINITE(s) /\
   involution f s /\
   (!x:A. x IN s ==> ~(f x = x)) /\
   a IN s
   ==> FINITE(s DIFF {a, (f a)}) /\
       involution f (s DIFF {a, (f a)}) /\
       (!x:A. x IN (s DIFF {a, (f a)}) ==> ~(f x = x)) /\
       (CARD s = CARD(s DIFF {a, (f a)}) + 2)`,
  (* The first three conjuncts follow by simplification. *)
  SIMP_TAC[FINITE_DIFF; INVOLUTION_STEPDOWN; IN_DIFF] THEN
  STRIP_TAC THEN
  (* Write s as a and f a inserted into the remainder. *)
  SUBGOAL_THEN `s = (a:A) INSERT (f a) INSERT (s DIFF {a, (f a)})`
    MP_TAC
  THENL
   [REWRITE_TAC[EXTENSION; IN_INSERT; IN_DIFF; NOT_IN_EMPTY] THEN
    ASM_MESON_TAC[involution];
    ALL_TAC] THEN
  (* Apply that equation only under the left-hand CARD. *)
  DISCH_THEN(fun decomp ->
    GEN_REWRITE_TAC (LAND_CONV o RAND_CONV) [decomp]) THEN
  ASM_SIMP_TAC[CARD_CLAUSES; FINITE_DIFF; FINITE_INSERT] THEN
  ASM_SIMP_TAC[IN_INSERT; IN_DIFF; NOT_IN_EMPTY] THEN
  ARITH_TAC);;
(* Evenness of CARD s for a fixpoint-free involution, stated with an        *)
(* explicit size n so that wellfounded induction on n (num_WF) applies.     *)
let INVOLUTION_EVEN_INDUCT = prove
 (`!n s. FINITE(s) /\
         (CARD s = n) /\
         involution f s /\
         (!x:A. x IN s ==> ~(f x = x))
         ==> EVEN(CARD s)`,
  MATCH_MP_TAC num_WF THEN
  GEN_TAC THEN DISCH_TAC THEN GEN_TAC THEN
  (* The empty set is handled directly by rewriting. *)
  ASM_CASES_TAC `s:A->bool = {}` THEN
  ASM_REWRITE_TAC[CARD_CLAUSES; ARITH] THEN
  (* Otherwise pick some element a of s. *)
  FIRST_X_ASSUM(MP_TAC o GEN_REWRITE_RULE RAND_CONV [EXTENSION]) THEN
  REWRITE_TAC[NOT_IN_EMPTY; NOT_FORALL_THM] THEN
  DISCH_THEN(X_CHOOSE_THEN `a:A` STRIP_ASSUME_TAC) THEN
  STRIP_TAC THEN
  (* Instantiate the induction hypothesis at s with a and f a removed. *)
  FIRST_X_ASSUM(MP_TAC o SPEC `CARD(s DIFF {a:A, (f a)})`) THEN
  REWRITE_TAC[RIGHT_IMP_FORALL_THM] THEN
  DISCH_THEN(MP_TAC o SPEC `s DIFF {a:A, (f a)}`) THEN
  MP_TAC INVOLUTION_EVEN_STEP THEN
  ASM_REWRITE_TAC[] THEN
  STRIP_TAC THEN
  ASM_REWRITE_TAC[ARITH_RULE `n < n + 2`] THEN
  SIMP_TAC[EVEN_ADD; ARITH]);;
(* A finite set carrying a fixpoint-free involution has even cardinality.   *)
(* This is INVOLUTION_EVEN_INDUCT without the auxiliary size variable.      *)
let INVOLUTION_EVEN = prove
 (`!s. FINITE(s) /\
       involution f s /\
       (!x:A. x IN s ==> ~(f x = x))
       ==> EVEN(CARD s)`,
  MESON_TAC[INVOLUTION_EVEN_INDUCT]);;
(* ------------------------------------------------------------------------- *)
(* So an involution with exactly one fixpoint has odd card domain.           *)
(* ------------------------------------------------------------------------- *)
(* A finite set on which an involution has exactly one fixpoint has odd     *)
(* cardinality: delete the fixpoint and apply INVOLUTION_EVEN to the rest.  *)
let INVOLUTION_FIX_ODD = prove
 (`FINITE(s) /\
   involution f s /\
   (?!a:A. a IN s /\ (f a = a))
   ==> ODD(CARD s)`,
  REWRITE_TAC[EXISTS_UNIQUE_DEF] THEN
  STRIP_TAC THEN
  (* Split off the fixpoint a. *)
  SUBGOAL_THEN `s = (a:A) INSERT (s DELETE a)` SUBST1_TAC THENL
   [REWRITE_TAC[EXTENSION; IN_INSERT; IN_DELETE] THEN
    ASM_MESON_TAC[];
    ALL_TAC] THEN
  ASM_SIMP_TAC[CARD_CLAUSES; FINITE_DELETE; IN_DELETE; ODD; NOT_ODD] THEN
  MATCH_MP_TAC INVOLUTION_EVEN THEN
  ASM_SIMP_TAC[INVOLUTION_DELETE; FINITE_DELETE; IN_DELETE] THEN
  ASM_MESON_TAC[]);;
(* ------------------------------------------------------------------------- *)
(* And an involution on a set of odd finite card must have a fixpoint.       *)
(* ------------------------------------------------------------------------- *)
(* An involution on a finite set of odd cardinality has a fixpoint; this is *)
(* the contrapositive of INVOLUTION_EVEN.                                   *)
(* NOTE(review): the bound variable n does not occur in the body; it is     *)
(* kept so that the statement of the theorem stays exactly as it was.       *)
let INVOLUTION_ODD = prove
 (`!n s. FINITE(s) /\
         involution f s /\
         ODD(CARD s)
         ==> ?a. a IN s /\ (f a = a)`,
  REWRITE_TAC[GSYM NOT_EVEN] THEN
  MESON_TAC[INVOLUTION_EVEN]);;
(* ------------------------------------------------------------------------- *)
(* Consequently, if one involution has a unique fixpoint, another has one.   *)
(* ------------------------------------------------------------------------- *)
(* If f and g are involutions on the same finite set s and f has exactly    *)
(* one fixpoint in s, then g has at least one fixpoint in s. The proof      *)
(* chains INVOLUTION_FIX_ODD (for f) into INVOLUTION_ODD (for g).           *)
let INVOLUTION_FIX_FIX = prove
 (`!f g s. FINITE(s) /\
           involution f s /\
           involution g s /\
           (?!x. x IN s /\ (f x = x))
           ==> ?x. x IN s /\ (g x = x)`,
  REPEAT STRIP_TAC THEN
  MATCH_MP_TAC INVOLUTION_ODD THEN
  ASM_REWRITE_TAC[] THEN
  MATCH_MP_TAC INVOLUTION_FIX_ODD THEN
  ASM_REWRITE_TAC[]);;
(* ------------------------------------------------------------------------- *)
(* Formalization of Zagier's "one-sentence" proof over the natural numbers.  *)
(* ------------------------------------------------------------------------- *)
(* zset(a) is the set of triples (x,y,z) satisfying x EXP 2 + 4 * y * z = a. *)
let zset = new_definition
  `zset(a) = {(x,y,z) | x EXP 2 + 4 * y * z = a}`;;
(* The map zag on triples, given by a three-way case split:                  *)
(*   x + z < y            gives (x + 2 * z, z, y - (x + z))                  *)
(*   else if x < 2 * y    gives (2 * y - x, y, (x + z) - y)                  *)
(*   otherwise            gives (x - 2 * y, (x + z) - y, y)                  *)
(* ZAG_INVOLUTION later shows it is an involution on zset(p) for prime p.    *)
let zag = new_definition
  `zag(x,y,z) =
        if x + z < y then (x + 2 * z,z,y - (x + z))
        else if x < 2 * y then (2 * y - x, y, (x + z) - y)
        else (x - 2 * y,(x + z) - y, y)`;;
(* The map tag on triples of naturals: swap the second and third components. *)
let tag = new_definition
  `tag((x,y,z):num#num#num) = (x,z,y)`;;
(* On triples with all components positive, applying zag twice is the       *)
(* identity. Both applications are unfolded and case-split, and the         *)
(* resulting arithmetic goals are discharged by ARITH_TAC.                  *)
let ZAG_INVOLUTION_GENERAL = prove
 (`0 < x /\ 0 < y /\ 0 < z
   ==> (zag(zag(x,y,z)) = (x,y,z))`,
  (* Unfold one application and split on its conditionals. *)
  REWRITE_TAC[zag] THEN
  REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN
  (* Unfold the remaining application and split again. *)
  REWRITE_TAC[zag] THEN
  REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN
  REWRITE_TAC[PAIR_EQ] THEN
  (* Move every assumption back into the goal for the arithmetic solver. *)
  POP_ASSUM_LIST(MP_TAC o end_itlist CONJ) THEN
  ARITH_TAC);;
(* Membership of a concrete triple in a set comprehension over triples      *)
(* reduces to the defining predicate applied to its components.             *)
let IN_TRIPLE = prove
 (`(a,b,c) IN {(x,y,z) | P x y z} <=> P a b c`,
  REWRITE_TAC[IN_ELIM_THM; PAIR_EQ] THEN
  MESON_TAC[]);;
(* No square n * n is prime. The cases n = 0 and n = 1 are split off;       *)
(* otherwise n is offered as the divisor witness.                           *)
let PRIME_SQUARE = prove
 (`!n. ~prime(n * n)`,
  GEN_TAC THEN
  ASM_CASES_TAC `n = 0` THEN
  ASM_REWRITE_TAC[PRIME_0; MULT_CLAUSES] THEN
  REWRITE_TAC[prime; NOT_FORALL_THM; DE_MORGAN_THM] THEN
  ASM_CASES_TAC `n = 1` THEN
  ASM_REWRITE_TAC[ARITH] THEN
  DISJ2_TAC THEN
  EXISTS_TAC `n:num` THEN
  ASM_SIMP_TAC[DIVIDES_LMUL; DIVIDES_REFL] THEN
  (* Rewrite one side as n * 1 so that the common factor can be cancelled. *)
  GEN_REWRITE_TAC (RAND_CONV o LAND_CONV) [ARITH_RULE `n = n * 1`] THEN
  ASM_SIMP_TAC[EQ_MULT_LCANCEL]);;
(* No multiple of 4 is prime; 2 is offered as the divisor witness. *)
let PRIME_4X = prove
 (`!n. ~prime(4 * n)`,
  GEN_TAC THEN
  REWRITE_TAC[prime; NOT_FORALL_THM; DE_MORGAN_THM] THEN
  DISJ2_TAC THEN
  EXISTS_TAC `2` THEN
  (* Replace 4 by 2 * 2, using the equation computed for 2 * 2 reversed. *)
  SUBST1_TAC(SYM(NUM_REDUCE_CONV `2 * 2`)) THEN
  ASM_SIMP_TAC[GSYM MULT_ASSOC; DIVIDES_RMUL; DIVIDES_REFL; ARITH_EQ] THEN
  ASM_CASES_TAC `n = 0` THEN
  POP_ASSUM MP_TAC THEN
  ARITH_TAC);;
(* If x EXP 2 + 4 * y * z is prime then x, y and z are all positive.        *)
(* Proved by contraposition: substituting 0 for any one of them leaves a    *)
(* term that PRIME_SQUARE or PRIME_4X shows is not prime.                   *)
let PRIME_XYZ_NONZERO = prove
 (`prime(x EXP 2 + 4 * y * z)
   ==> 0 < x /\ 0 < y /\ 0 < z`,
  CONV_TAC CONTRAPOS_CONV THEN
  REWRITE_TAC[DE_MORGAN_THM; ARITH_RULE `~(0 < x) = (x = 0)`] THEN
  DISCH_THEN(REPEAT_TCL DISJ_CASES_THEN SUBST1_TAC) THEN
  REWRITE_TAC[EXP_2; MULT_CLAUSES; ADD_CLAUSES;
              PRIME_SQUARE; PRIME_4X]);;
(* For prime p, zag is an involution on zset(p). Two obligations:           *)
(* closure of zset(p) under zag, and zag applied twice being the identity.  *)
let ZAG_INVOLUTION = prove
 (`!p. prime(p) ==> involution zag (zset(p))`,
  REPEAT STRIP_TAC THEN
  REWRITE_TAC[involution; FORALL_PAIR_THM] THEN
  MAP_EVERY X_GEN_TAC [`x:num`; `y:num`; `z:num`] THEN
  REWRITE_TAC[zset; IN_TRIPLE] THEN
  DISCH_THEN(SUBST_ALL_TAC o SYM) THEN
  CONJ_TAC THENL
   [(* Closure: case-split zag, then move the arithmetic to the integers. *)
    REWRITE_TAC[zag] THEN
    REPEAT COND_CASES_TAC THEN
    ASM_REWRITE_TAC[IN_TRIPLE] THEN
    RULE_ASSUM_TAC(REWRITE_RULE[NOT_LT]) THEN
    ASM_SIMP_TAC[GSYM INT_OF_NUM_EQ; GSYM INT_OF_NUM_ADD; EXP_2;
                 GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_SUB; LT_IMP_LE] THEN
    INT_ARITH_TAC;
    (* Self-inverse: the general lemma applies since x, y, z are positive. *)
    MATCH_MP_TAC ZAG_INVOLUTION_GENERAL THEN
    ASM_MESON_TAC[PRIME_XYZ_NONZERO]]);;
(* tag is an involution on zset a for every a; after unfolding, the goal    *)
(* is closed by normalising products with MULT_AC.                          *)
let TAG_INVOLUTION = prove
 (`!a. involution tag (zset a)`,
  REWRITE_TAC[involution; tag; zset; FORALL_PAIR_THM] THEN
  REWRITE_TAC[IN_TRIPLE] THEN
  REWRITE_TAC[MULT_AC]);;
(* Any fixpoint (x,y,z) of zag has y = x. After unfolding zag and           *)
(* case-splitting, each case is a linear arithmetic goal over num.          *)
(* The rewrite list holds only zag: the goal has no integer power term, so  *)
(* the integer theorem INT_POW_2 that used to be listed could never fire.   *)
let ZAG_LEMMA = prove
 (`(zag(x,y,z) = (x,y,z)) ==> (y = x)`,
  REWRITE_TAC[zag] THEN
  REPEAT(COND_CASES_TAC THEN ASM_SIMP_TAC[PAIR_EQ]) THEN
  (* Move every case assumption back into the goal for the solver. *)
  POP_ASSUM_LIST(MP_TAC o end_itlist CONJ) THEN
  ARITH_TAC);;
(* If y and z are positive and x EXP 2 + 4 * y * z = p, then each of x, y   *)
(* and z is at most p.                                                      *)
(* The goal has no outer universal quantifier, so the proof starts with     *)
(* STRIP_TAC; the former leading REPEAT GEN_TAC had nothing to strip.       *)
let ZSET_BOUND = prove
 (`0 < y /\ 0 < z /\ (x EXP 2 + 4 * y * z = p)
   ==> x <= p /\ y <= p /\ z <= p`,
  STRIP_TAC THEN
  FIRST_X_ASSUM(SUBST1_TAC o SYM) THEN
  CONJ_TAC THENL
   [(* Bound on x: x is at most its square, hence at most the sum. *)
    MESON_TAC[EXP_2; LE_SQUARE_REFL;
              ARITH_RULE `(a <= b ==> a <= b + c)`];
    (* Bounds on y and z: drop the square, then compare with 4 * y * z. *)
    CONJ_TAC THEN
    MATCH_MP_TAC(ARITH_RULE `y <= z ==> y <= x + z`) THENL
     [GEN_REWRITE_TAC (RAND_CONV o RAND_CONV) [MULT_SYM];
      ALL_TAC] THEN
    REWRITE_TAC[ARITH_RULE
      `y <= 4 * a * y <=> 1 * y <= (4 * a) * y`] THEN
    ASM_REWRITE_TAC[LE_MULT_RCANCEL] THEN
    ASM_SIMP_TAC[ARITH_RULE `0 < a ==> 1 <= 4 * a`]]);;
(* For prime p the set zset p is finite. It is shown to be a subset of a    *)
(* finite product built from FINITE_NUMSEG_LT at p + 1, with the component  *)
(* bounds supplied by ZSET_BOUND and PRIME_XYZ_NONZERO.                     *)
let ZSET_FINITE = prove
 (`!p. prime(p) ==> FINITE(zset p)`,
  GEN_TAC THEN DISCH_TAC THEN
  (* Finiteness of the segment below p + 1, combined twice into a product. *)
  MP_TAC(SPEC `p + 1` FINITE_NUMSEG_LT) THEN
  DISCH_THEN(fun fin_seg ->
    MP_TAC(funpow 2 (MATCH_MP FINITE_PRODUCT o CONJ fin_seg) fin_seg)) THEN
  (* Reduce to a subset goal via FINITE_SUBSET with its hypotheses swapped. *)
  MATCH_MP_TAC(REWRITE_RULE[TAUT `a /\ b ==> c <=> b ==> a ==> c`]
                 FINITE_SUBSET) THEN
  REWRITE_TAC[zset; SUBSET; FORALL_PAIR_THM; IN_TRIPLE] THEN
  MAP_EVERY X_GEN_TAC [`x:num`; `y:num`; `z:num`] THEN
  REWRITE_TAC[IN_ELIM_THM; EXISTS_PAIR_THM; PAIR_EQ] THEN
  REWRITE_TAC[ARITH_RULE `x < p + 1 <=> x <= p`; PAIR_EQ] THEN
  DISCH_TAC THEN
  MAP_EVERY EXISTS_TAC [`x:num`; `y:num`; `z:num`] THEN
  ASM_REWRITE_TAC[] THEN
  REWRITE_TAC[RIGHT_AND_EXISTS_THM] THEN
  MAP_EVERY EXISTS_TAC [`y:num`; `z:num`] THEN
  REWRITE_TAC[] THEN
  ASM_MESON_TAC[ZSET_BOUND; PRIME_XYZ_NONZERO]);;
(* Main theorem: a prime of the form 4 * k + 1 is a sum of two squares.     *)
(* Outline: zag is an involution on zset(p) whose only fixpoint is (1,1,k), *)
(* so by INVOLUTION_FIX_FIX the involution tag also has a fixpoint there.   *)
(* A fixpoint of tag has equal second and third components, which turns     *)
(* the defining equation of zset(p) into a sum of two squares.              *)
let SUM_OF_TWO_SQUARES = prove
 (`!p k. prime(p) /\ (p = 4 * k + 1)
         ==> ?x y. p = x EXP 2 + y EXP 2`,
  SIMP_TAC[] THEN
  REPEAT STRIP_TAC THEN
  (* Reduce the goal to the existence of a fixpoint of tag in zset(p). *)
  SUBGOAL_THEN `?t. t IN zset(p) /\ (tag(t) = t)` MP_TAC THENL
   [ALL_TAC;
    (* Given such a fixpoint, read off the two squares. *)
    REWRITE_TAC[LEFT_IMP_EXISTS_THM; FORALL_PAIR_THM; tag; PAIR_EQ] THEN
    REWRITE_TAC[zset; IN_TRIPLE; EXP_2] THEN
    ASM_MESON_TAC[ARITH_RULE `4 * x * y = (2 * x) * (2 * y)`]] THEN
  (* Transfer the unique fixpoint of zag to a fixpoint of tag. *)
  MATCH_MP_TAC INVOLUTION_FIX_FIX THEN
  EXISTS_TAC `zag` THEN
  ASM_SIMP_TAC[ZAG_INVOLUTION; TAG_INVOLUTION; ZSET_FINITE] THEN
  (* Claim: the unique fixpoint of zag in zset(p) is (1,1,k). *)
  REWRITE_TAC[EXISTS_UNIQUE_ALT] THEN
  EXISTS_TAC `1,1,k:num` THEN
  REWRITE_TAC[FORALL_PAIR_THM] THEN
  MAP_EVERY X_GEN_TAC [`x:num`; `y:num`; `z:num`] THEN
  EQ_TAC THENL
   [ALL_TAC;
    (* Easy direction: (1,1,k) lies in zset(p) and is fixed by zag. *)
    DISCH_THEN(SUBST1_TAC o SYM) THEN
    REWRITE_TAC[zset; zag; IN_TRIPLE; ARITH] THEN
    REWRITE_TAC[MULT_CLAUSES; ARITH_RULE `~(1 + k < 1)`; PAIR_EQ] THEN
    ARITH_TAC] THEN
  (* Hard direction: any fixpoint has y = x, so p factors as below. *)
  REWRITE_TAC[zset; IN_TRIPLE] THEN
  STRIP_TAC THEN
  FIRST_ASSUM(SUBST_ALL_TAC o MATCH_MP ZAG_LEMMA) THEN
  UNDISCH_TAC `x EXP 2 + 4 * x * z = 4 * k + 1` THEN
  REWRITE_TAC[EXP_2;
              ARITH_RULE `x * x + 4 * x * z = x * (4 * z + x)`] THEN
  DISCH_THEN(ASSUME_TAC o SYM) THEN
  (* Use primality: the divisor x is either 1 or the whole product. *)
  UNDISCH_TAC `prime p` THEN
  ASM_REWRITE_TAC[] THEN
  REWRITE_TAC[prime] THEN
  DISCH_THEN(CONJUNCTS_THEN2 ASSUME_TAC (MP_TAC o SPEC `x:num`)) THEN
  SIMP_TAC[DIVIDES_RMUL; DIVIDES_REFL] THEN
  DISCH_THEN(DISJ_CASES_THEN2 SUBST_ALL_TAC MP_TAC) THENL
   [(* Case x = 1: the remaining equation is linear. *)
    UNDISCH_TAC `4 * k + 1 = 1 * (4 * z + 1)` THEN
    REWRITE_TAC[MULT_CLAUSES; PAIR_EQ] THEN
    ARITH_TAC;
    (* Case x equal to the product: cancel x, leaving two subcases. *)
    ONCE_REWRITE_TAC[ARITH_RULE `(a = a * b) = (a * b = a * 1)`] THEN
    ASM_SIMP_TAC[EQ_MULT_LCANCEL] THEN
    STRIP_TAC THENL
     [UNDISCH_TAC `4 * k + 1 = x * (4 * z + x)` THEN
      ASM_REWRITE_TAC[MULT_CLAUSES; ADD_EQ_0; ARITH_EQ];
      UNDISCH_TAC `4 * z + x = 1` THEN
      REWRITE_TAC[PAIR_EQ] THEN
      ASM_CASES_TAC `z = 0` THENL
       [ALL_TAC;
        UNDISCH_TAC `~(z = 0)` THEN ARITH_TAC] THEN
      UNDISCH_TAC `4 * k + 1 = x * (4 * z + x)` THEN
      ASM_REWRITE_TAC[MULT_CLAUSES; ADD_CLAUSES] THEN
      ASM_CASES_TAC `x = 1` THEN
      ASM_REWRITE_TAC[] THEN
      REWRITE_TAC[MULT_CLAUSES] THEN
      ARITH_TAC]]);;